我偶然发现了同样的问题,这就是我的理解方式!简约的LSTM示例:
import tensorflow as tfsample_input = tf.constant([[1,2,3]],dtype=tf.float32)LSTM_CELL_SIZE = 2lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2output, state_new = lstm_cell(sample_input, state)init_op = tf.global_variables_initializer()sess = tf.Session()sess.run(init_op)print sess.run(output)
请注意,
state_is_tuple=True因此在传递
state给this时
cell,它必须采用
tuple表格形式。
c_state并且
m_state可能是“内存状态”和“单元状态”,尽管老实说我不确定,因为这些术语仅在文档中提及。在代码和文件中,关于
LSTM-字母
h和
c通常用于表示“输出值”和“单元状态”。
http://colah.github.io/posts/2015-08-Understanding-
LSTMs/
这些张量表示单元的组合内部状态,应该一起传递。这样做的旧方法是简单地将它们连接起来,而新方法是使用元组。
旧方法:
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=False)state = tf.zeros([1,LSTM_CELL_SIZE*2])output, state_new = lstm_cell(sample_input, state)
新方法:
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2output, state_new = lstm_cell(sample_input, state)
因此,基本上我们所做的一切,都
state从长度的1张量更改为长度的
42张量
2。内容保持不变。
[0,0,0,0]成为
([0,0],[0,0])。(这应该使其速度更快)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)