Tensorflow LSTM中的c_state和m_state是什么?

Tensorflow LSTM中的c_state和m_state是什么?,第1张

Tensorflow LSTM中的c_state和m_state是什么?

我偶然发现了同样的问题,这就是我的理解方式!简约的LSTM示例:

import tensorflow as tfsample_input = tf.constant([[1,2,3]],dtype=tf.float32)LSTM_CELL_SIZE = 2lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2output, state_new = lstm_cell(sample_input, state)init_op = tf.global_variables_initializer()sess = tf.Session()sess.run(init_op)print sess.run(output)

请注意,

state_is_tuple=True
因此在传递
state
给this时
cell
,它必须采用
tuple
表格形式。
c_state
并且
m_state
可能是“内存状态”和“单元状态”,尽管老实说我不确定,因为这些术语仅在文档中提及。在代码和文件中,关于
LSTM
-字母
h
c
通常用于表示“输出值”和“单元状态”。
http://colah.github.io/posts/2015-08-Understanding-
LSTMs/

这些张量表示单元的组合内部状态,应该一起传递。这样做的旧方法是简单地将它们连接起来,而新方法是使用元组。

旧方法:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=False)state = tf.zeros([1,LSTM_CELL_SIZE*2])output, state_new = lstm_cell(sample_input, state)

新方法:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2output, state_new = lstm_cell(sample_input, state)

因此,基本上我们所做的一切,都

state
从长度的1张量更改为长度的
4
2张量
2
。内容保持不变。
[0,0,0,0]
成为
([0,0],[0,0])
。(这应该使其速度更快)



欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5508117.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-13
下一篇 2022-12-13

发表评论

登录后才能评论

评论列表(0条)

保存