PyTorch深度学习总结--12

PyTorch深度学习总结--12,第1张

PyTorch深度学习总结--12 PyTorch深度学习总结–12_Basic_RNN

12_Basic_RNN

任务介绍

通过RnnCell学习将hello转化为ohlol

使用embedding进行压缩

如何使用embedding压缩,需要将字母转换为对应的数字。
这里采用字母下标当做该字母表示的数字

转换为embedding向量

输入形式

输入x是每个字母
输出o是x转换为对应的字母

实现代码

python 语言,PyTorch实现

import torch

idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1]]
y_data = [3, 1, 2, 3, 2, 1, 0, 2, 2, 3, 3]

batch_size = 1  # 表示hidden的维度
input_size = 4  # 表示母用的种类数
hidden_size = 10  # 表示hidden这个向量的长度
num_layers = 2  # 表示使用了几层RNN结构
embedding_size = 10  # 表示每一个字母用长度为10的向量表示,要作为RNN的输入
num_class = 4  # 表示单词的种类数
inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)


class Model(torch.nn.Module):
    def __init__(self, input_size, hidden_size, batch_size, num_layers, embedding_size, num_class):
        super(Model, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_class = num_class
        self.embedding_size = embedding_size

        self.linear = torch.nn.Linear(hidden_size, num_class)
        self.emb = torch.nn.Embedding(input_size, embedding_size)
        self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                                batch_first=True)

    def forward(self, x):
        hidden = self.init_hidden()
        x = self.emb(x)
        x, hidden = self.rnn(x, hidden)
        x = self.linear(x)
        return x.view(-1, self.num_class)

    def init_hidden(self):
        return torch.zeros(self.num_layers, self.batch_size, self.hidden_size)


net = Model(input_size, hidden_size, batch_size, num_layers, embedding_size, num_class)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

for epoch in range(50):
    optimizer.zero_grad()
    out = net(inputs)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    _, idx = out.max(dim=1)
    idx = idx.data.numpy()
    print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
    print(', Epoch [%d/50] loss = %.3f' % (epoch + 1, loss.item()))

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5710574.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存