I found it easiest to save the whole state for all layers in a single placeholder.
    import numpy as np
    import tensorflow as tf

    # shape (num_layers, 2, batch_size, state_size); the "2" holds c and h per layer
    init_state = np.zeros((num_layers, 2, batch_size, state_size))
    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
Then unpack it and build a tuple of LSTMStateTuples before handing it to the native TensorFlow RNN API:
    # tf.unpack was renamed to tf.unstack in TensorFlow 1.0
    l = tf.unstack(state_placeholder, axis=0)
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)])
Then pass it to the RNN API as the initial state:
    # build a fresh cell per layer; reusing one cell object ([cell] * num_layers)
    # raises an error in TensorFlow >= 1.1
    cells = [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
             for _ in range(num_layers)]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
    outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)
The state returned by dynamic_rnn is then fed back into the placeholder for the next batch.
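A minimal sketch of that feedback loop, assuming the graph built above, might look like the following (get_batch, train_op, and num_steps are hypothetical stand-ins for your data pipeline, optimizer step, and iteration count):

    # minimal sketch, assuming the graph above
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        current_state = np.zeros((num_layers, 2, batch_size, state_size))
        for step in range(num_steps):          # num_steps: hypothetical
            x_batch = get_batch(step)          # get_batch: hypothetical data helper
            # fetch the final state and feed it back in on the next iteration;
            # the nested tuple of LSTMStateTuples coerces to an array of shape
            # (num_layers, 2, batch_size, state_size), matching the placeholder
            current_state, _ = sess.run(
                [state, train_op],             # train_op: hypothetical optimizer step
                feed_dict={x_input_batch: x_batch,
                           state_placeholder: current_state})

Resetting current_state to zeros between sequences gives the usual stateless behavior; carrying it across batches is what makes the LSTM stateful.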