python – 如何在Tensorflow中加载预训练的LSTM模型权重

python – 如何在Tensorflow中加载预训练的LSTM模型权重,第1张

概述我想在Tensorflow中实现具有预训练权重的LSTM模型.这些重量可能来自Caffee或Torch. 我发现文件rnn_cell.py中有LSTM单元,例如rnn_cell.BasicLSTMCell和rnn_cell.MultiRNNCell.但是如何为这些LSTM单元加载预训练的权重. 这是一个加载预先训练的Caffe模型的解决方案.见 full code here,在 this thre 我想在Tensorflow中实现具有预训练权重的LSTM模型.这些重量可能来自Caffee或Torch.
我发现文件rnn_cell.py中有LSTM单元,例如rnn_cell.BasicLSTMCell和rnn_cell.MultiRNNCell.但是如何为这些LSTM单元加载预训练的权重.解决方法 这是一个加载预先训练的Caffe模型的解决方案.见 full code here,在 this thread的讨论中引用.

net_caffe = caffe.Net(prototxt,caffemodel,caffe.TEST)caffe_layers = {}for i,layer in enumerate(net_caffe.layers):    layer_name = net_caffe._layer_names[i]    caffe_layers[layer_name] = layerdef caffe_weights(layer_name):    layer = caffe_layers[layer_name]    return layer.blobs[0].datadef caffe_bias(layer_name):    layer = caffe_layers[layer_name]    return layer.blobs[1].data#tensorflow uses [filter_height,filter_wIDth,in_channels,out_channels] 2-3-1-0 #caffe uses [out_channels,filter_height,filter_wIDth] 0-1-2-3def caffe2tf_filter(name):    f = caffe_weights(name)    return f.transpose((2,3,1,0))class ModelFromCaffe():    def get_conv_filter(self,name):        w = caffe2tf_filter(name)        return tf.constant(w,dtype=tf.float32,name="filter")    def get_bias(self,name):        b = caffe_bias(name)        return tf.constant(b,name="bias")    def get_fc_weight(self,name):        cw = caffe_weights(name)        if name == "fc6":            assert cw.shape == (4096,25088)            cw = cw.reshape((4096,512,7,7))             cw = cw.transpose((2,0))            cw = cw.reshape(25088,4096)        else:            cw = cw.transpose((1,0))        return tf.constant(cw,name="weight")images = tf.placeholder("float",[None,224,3],name="images")m = ModelFromCaffe()with tf.Session() as sess:  sess.run(tf.initialize_all_variables())  batch = cat.reshape((1,3))  out = sess.run([m.prob,m.relu1_1,m.pool5,m.fc6],Feed_dict={ images: batch })...
总结

以上是内存溢出为你收集整理的python – 如何在Tensorflow中加载预训练的LSTM模型权重全部内容,希望文章能够帮你解决python – 如何在Tensorflow中加载预训练的LSTM模型权重所遇到的程序开发问题。

如果觉得内存溢出网站内容还不错,欢迎将内存溢出网站推荐给程序员好友。

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/1196153.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-06-03
下一篇 2022-06-03

发表评论

登录后才能评论

评论列表(0条)

保存