Pytorch简单案例:y=wx+b参数训练

Pytorch简单案例:y=wx+b参数训练,第1张

Pytorch简单案例:y=wx+b参数训练

简单练习一下Pytorch:目标方程是y=2x+3,使用六个数据样本进行100次迭代

// An highlighted block
import torch
import numpy
import matplotlib.pyplot as plt

x_data=[1.0,2.0,3.0,4.0,5.0,6.0]
y_data=[5.0,7.0,9.0,11.0,13.0,15.0]
w=torch.tensor([1.0],requires_grad=True)
b=torch.tensor([1.0],requires_grad=True)
w_list=[]
b_list=[]

def pred(x):
    return x*w+b

def loss(x,y):
    l=(pred(x)-y)**2
    return l

for epoch in range(100):
    w_list.append(w.data.item())
    b_list.append(b.data.item())
    for xs,ys in zip(x_data,y_data):
        lo=loss(xs,ys)
        lo.backward()
        w.data=w.data-0.01*w.grad.item()
        b.data=b.data-0.01*b.grad.item()
        w.grad.data.zero_()
        b.grad.data.zero_()
    print("epoch:",epoch,"loss:",lo.item())

print("预测值为:",pred(5).item())

plt.plot(numpy.arange(100),w_list,color="blue",label='parameter_w')
plt.plot(numpy.arange(100),b_list,color="red",label='parameter_b')
plt.show()

测试结果:蓝线是w,最终逼近2,红线是b,最终逼近3

上面的模型是手动进行梯度下降,下面这个是自动实现的梯度下降

import torch
import numpy
import matplotlib.pyplot as plt

x_data=torch.tensor([[1.0],[2.0],[3.0],[4.0]])
y_data=torch.tensor([[6.0],[9.0],[12.0],[15.0]])
loss_list=[]
#创建模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear=torch.nn.Linear(1,1)

    def forward(self,x):
        y_pred=self.linear(x)
        return y_pred
#损失函数和优化器
model=LinearModel()
criterion=torch.nn.MSELoss(reduction='mean')
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
#模型迭代,3000次
for epoch in range(3000):
    y_pre=model(x_data)
    loss=criterion(y_pre,y_data)
    loss_list.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

epoch_list=numpy.arange(3000)
print(model.linear.weight.item(),model.linear.bias.item())#输出权值和偏差
test_set=torch.tensor([[8.0]])
print(model(test_set).item())
#可视化训练过程
plt.plot(epoch_list,loss_list,label='loss',color='blue')
plt.show()

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/4654595.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-11-06
下一篇 2022-11-06

发表评论

登录后才能评论

评论列表(0条)

保存