TensorFlow:记住下一批的LSTM状态(有状态LSTM)

TensorFlow:记住下一批的LSTM状态(有状态LSTM),第1张

TensorFlow:记住下一批的LSTM状态(有状态LSTM)

我发现在占位符中保存所有图层的整个状态是最容易的。

init_state = np.zeros((num_layers, 2, batch_size, state_size))...state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

然后解压缩它并创建一个LSTMStateTuples元组,然后再使用本机tensorflow RNN Api。

l = tf.unpack(state_placeholder, axis=0)rnn_tuple_state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1]) for idx in range(num_layers)])

RNN传递API:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)cell = tf.nn.rnn_cell.MultiRNNCell([cell]*num_layers, state_is_tuple=True)outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)

state
-那么变量将被feeded下一批作为占位符。



欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5623504.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-15
下一篇 2022-12-15

发表评论

登录后才能评论

评论列表(0条)

保存