【Reference: 详解LSTM - 知乎】
【Reference: YJango的循环神经网络——实现LSTM - 知乎】 Strongly recommended reading.
- Input gate i_t
- Forget gate f_t
- Output gate o_t
- Cell gate g_t
In the usual LSTM diagram, the small circle denotes element-wise (Hadamard) multiplication; the full update equations are given below.
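For reference, the per-time-step LSTM update equations, written in the same form as the PyTorch nn.LSTM documentation (⊙ is the element-wise product mentioned above):
i_t = σ(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})
f_t = σ(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})
g_t = tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})
o_t = σ(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})
c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
h_t = o_t ⊙ tanh(c_t)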
【Reference: 30、PyTorch LSTM和LSTMP的原理及其手写复现_哔哩哔哩_bilibili】
Single-layer, unidirectional LSTM 【Reference: 30 - LSTM,LSTMP手撸代码_取个名字真难呐的博客-CSDN博客】
import torch
from torch import nn
torch.manual_seed(0)  # fix the random seed so the random tensors are reproducible
batch_size = 2  # batch size
seq_len = 3  # input sequence length
input_size = 4  # feature size of each input step
hidden_size = 5  # feature size of the hidden state
num_layers = 1  # number of stacked layers
# randomly initialize the input
input_one = torch.randn(batch_size, seq_len, input_size)  # (bs, seq_len, input_size) random feature sequence
# initialize the initial hidden and cell states
h0 = torch.zeros(batch_size, hidden_size)  # initial hidden state h_0, (bs, hidden_size); PyTorch defaults to all zeros
# Officially this should be (1, batch_size, hidden_size); it is kept 2-D here to simplify passing it to the
# custom function below, where most tensors are 2-D.
c0 = torch.zeros(batch_size, hidden_size)  # initial cell state; not a trainable parameter
# define the LSTM layer
lstm_layer = nn.LSTM(input_size=input_size,
                     hidden_size=hidden_size,
                     num_layers=num_layers,
                     batch_first=True)
output_api, (h_n_api, c_n_api) = lstm_layer(input_one, (h0.unsqueeze(0), c0.unsqueeze(0)))
# For batched input, h_0 must have shape (D*num_layers, bs, hidden_size); see the official parameter docs.
# h0.unsqueeze(0): (bs, hidden_size) -> (1, bs, hidden_size), since D = 1 and num_layers = 1 here.
print(f"output.shape={output_api.shape}")
print(f"h_n.shape={h_n_api.shape}")
print(f"c_n.shape={c_n_api.shape}")
print(f"output={output_api}")
print(f"h_n={h_n_api}")
print(f"c_n={c_n_api}")
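For reference, with the sizes above (batch_size=2, seq_len=3, hidden_size=5, D=1, num_layers=1) the printed shapes are (tensor values omitted here):
output.shape=torch.Size([2, 3, 5])  # (bs, seq_len, D*hidden_size), because batch_first=True
h_n.shape=torch.Size([1, 2, 5])  # (D*num_layers, bs, hidden_size)
c_n.shape=torch.Size([1, 2, 5])  # (D*num_layers, bs, hidden_size)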
# inspect the model parameters
for k, v in lstm_layer.named_parameters():
    print(k, v.shape)
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
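A quick way to see this gate-wise layout (a small illustration, not part of the original code) is to split weight_ih_l0 into four equal blocks along dim 0 with Tensor.chunk:
# weight_ih_l0: (4*hidden_size, input_size) -> four (hidden_size, input_size) blocks in i, f, g, o order
W_ii, W_if, W_ig, W_io = lstm_layer.weight_ih_l0.chunk(4, dim=0)
print(W_ii.shape)  # torch.Size([5, 4])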
"""
weight_ih_l0 is W_ii, W_if, W_ig, W_io concatenated along dim 0, i.e. (4*5, 4)
weight_hh_l0, bias_ih_l0 and bias_hh_l0 are concatenated in the same way
"""
# single-direction, single-layer custom_lstm_function
def custom_lstm_function(input, init_state, w_ih, w_hh, b_ih, b_hh):
    """
    :param input: (batch_size, seq_len, input_size)
    :param init_state: tuple (h0, c0), each of shape (batch_size, hidden_size)
    :param w_ih: W_ii, W_if, W_ig, W_io concatenated, shape (4*hidden_size, input_size)
    :param w_hh: W_hi, W_hf, W_hg, W_ho concatenated, shape (4*hidden_size, hidden_size)
    :param b_ih: the four input biases concatenated, shape (4*hidden_size,)
    :param b_hh: the four hidden biases concatenated, shape (4*hidden_size,)
    :return: output, (h_n, c_n)
    """
    # h0.shape = torch.Size([batch_size, hidden_size])
    # c0.shape = torch.Size([batch_size, hidden_size])
    h0, c0 = init_state
    h_prev = h0
    c_prev = c0
    batch_size, seq_len, input_size = input.shape
    hidden_size = w_ih.shape[0] // 4
    output = torch.zeros((batch_size, seq_len, hidden_size))
    # batch_w_ih.shape = (4*hidden_size, input_size) -> (batch_size, 4*hidden_size, input_size)
    batch_w_ih = w_ih.unsqueeze(0).tile([batch_size, 1, 1])
    # batch_w_hh.shape = (4*hidden_size, hidden_size) -> (batch_size, 4*hidden_size, hidden_size)
    batch_w_hh = w_hh.unsqueeze(0).tile([batch_size, 1, 1])
    for t in range(seq_len):
        # input.shape = torch.Size([batch_size, seq_len, input_size])
        # x.shape = torch.Size([batch_size, input_size]) -> (batch_size, input_size, 1)
        x = input[:, t, :].unsqueeze(-1)
        # w_ih_times_x.shape = torch.Size([batch_size, 4*hidden_size, 1]) -> ([batch_size, 4*hidden_size])
        w_ih_times_x = torch.bmm(batch_w_ih, x).squeeze(-1)
        # h_prev: (batch_size, hidden_size) -> (batch_size, hidden_size, 1)
        # w_hh_times_prev_h.shape = torch.Size([batch_size, 4*hidden_size, 1]) -> ([batch_size, 4*hidden_size])
        w_hh_times_prev_h = torch.bmm(batch_w_hh, h_prev.unsqueeze(-1)).squeeze(-1)
        # compute input gate i_t, forget gate f_t, cell gate g_t, output gate o_t, then c_t and h_t
        # i_t.shape = torch.Size([batch_size, hidden_size])
        i_t = torch.sigmoid(w_ih_times_x[:, :hidden_size]
                            + b_ih[:hidden_size]
                            + w_hh_times_prev_h[:, :hidden_size]
                            + b_hh[:hidden_size])
        # f_t.shape = torch.Size([batch_size, hidden_size])
        f_t = torch.sigmoid(w_ih_times_x[:, hidden_size:2 * hidden_size]
                            + b_ih[hidden_size:2 * hidden_size]
                            + w_hh_times_prev_h[:, hidden_size:2 * hidden_size]
                            + b_hh[hidden_size:2 * hidden_size])
        # g_t.shape = (batch_size, hidden_size)
        g_t = torch.tanh(w_ih_times_x[:, 2 * hidden_size:3 * hidden_size]
                         + b_ih[2 * hidden_size:3 * hidden_size]
                         + w_hh_times_prev_h[:, 2 * hidden_size:3 * hidden_size]
                         + b_hh[2 * hidden_size:3 * hidden_size])
        # o_t.shape = (batch_size, hidden_size)
        o_t = torch.sigmoid(w_ih_times_x[:, 3 * hidden_size:]
                            + b_ih[3 * hidden_size:]
                            + w_hh_times_prev_h[:, 3 * hidden_size:]
                            + b_hh[3 * hidden_size:])
        # c_prev stays (batch_size, hidden_size)
        c_prev = f_t * c_prev + i_t * g_t  # c_t
        # h_prev.shape = (batch_size, hidden_size)
        h_prev = o_t * torch.tanh(c_prev)  # h_t
        # output[:, t, :].shape = torch.Size([batch_size, hidden_size])
        output[:, t, :] = h_prev
    return output, (h_prev.unsqueeze(0), c_prev.unsqueeze(0))
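As a sanity check on the gate arithmetic, the same per-step update can be written more compactly with F.linear and chunk instead of tiling the weights for torch.bmm. This is an equivalent sketch (assuming the same i, f, g, o gate order used by PyTorch), not the implementation from the referenced video:
import torch.nn.functional as F

def lstm_step(x_t, h_prev, c_prev, w_ih, w_hh, b_ih, b_hh):
    # gates: (batch_size, 4*hidden_size) = x_t @ w_ih.T + b_ih + h_prev @ w_hh.T + b_hh
    gates = F.linear(x_t, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)
    i_t, f_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.sigmoid(o_t)
    g_t = torch.tanh(g_t)
    c_t = f_t * c_prev + i_t * g_t  # new cell state
    h_t = o_t * torch.tanh(c_t)  # new hidden state
    return h_t, c_t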
cu_input = input_one
cu_init_state = (h0,c0)
cu_weight_ih_l0 = lstm_layer.weight_ih_l0
cu_weight_hh_l0 = lstm_layer.weight_hh_l0
cu_bias_ih_l0 = lstm_layer.bias_ih_l0
cu_bias_hh_l0 = lstm_layer.bias_hh_l0
custom_output,(custom_hn,custom_cn) = custom_lstm_function(cu_input,cu_init_state,cu_weight_ih_l0,cu_weight_hh_l0,cu_bias_ih_l0,cu_bias_hh_l0)
print(f"custom_output.shape={custom_output.shape}")
print(f"custom_hn.shape={custom_hn.shape}")
print(f"custom_cn.shape={custom_cn.shape}")
print(f"custom_output={custom_output}")
print(f"custom_hn={custom_hn}")
print(f"custom_cn={custom_cn}")
print("output is equal ?")
print(torch.isclose(custom_output, output_api))
print("h_n is equal ?")
print(torch.isclose(custom_hn, h_n_api))
print("c_n is equal ?")
print(torch.isclose(custom_cn, c_n_api))
...
output is equal ?
tensor([[[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True]],
[[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True]]])
h_n is equal ?
tensor([[[True, True, True, True, True],
[True, True, True, True, True]]])
c_n is equal ?
tensor([[[True, True, True, True, True],
[True, True, True, True, True]]])
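Instead of the element-wise torch.isclose tables above, a single True/False per tensor can also be obtained with torch.allclose (a small convenience, not part of the original code):
print(torch.allclose(custom_output, output_api))  # expected: True
print(torch.allclose(custom_hn, h_n_api))  # expected: True
print(torch.allclose(custom_cn, c_n_api))  # expected: True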