- 以往的代码,都是随便看看就过去了,没有这样较真过,以至于看了很久的深度学习和Python,都没有能够形成编程能力;
- 这次算是废寝忘食的深入进去了,踏实地把每一个代码都理解透,包括其中的数学原理(目前涉及的还很浅)和代码语句实现的功能;也是得益于疫情封闭在寝室,才有如此踏实的心情和宽松的时间,最重要的是周边的环境,没有干扰。
#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: 24_nemo
@file: 04_BackPropagation_handType.py
@time: 2022/04/08
@desc:
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = torch.Tensor([1.0])
w.requires_grad = True
def forward(x):
return x * w
def loss(x, y):
y_hat = forward(x)
return (y_hat - y) ** 2
def gradient(x, y):
return 2 * x * (x * w - y)
print('Predict (before training)', 4, forward(4))
epoch_list = []
loss_list = []
for epoch in np.arange(100):
loss_val = 0
for x, y in zip(x_data, y_data):
loss_val = loss(x, y)
loss_val.backward()
print('\tgrad:', x, y, w.grad.item())
w.data -= 0.01 * w.grad.data
epoch_list.append(epoch)
loss_list.append(loss_val.item()) # 往列表里添加的内容,不可是tensor类型,或者item(),或者.detach().numpy()取出来
w.grad.data.zero_()
print('process:', epoch, loss_val.item())
print('Predict(after training)', 4, forward(4).item())
plt.plot(epoch_list, loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
运行结果:
默写出的错误:
经过修改之后的代码:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# w = torch.Tensor([1.0])
# w.requires_grad = True
w = torch.Tensor([1.0])
w.requires_grad = True
def forward(x):
return x * w
def loss(x, y):
y_hat = forward(x)
return (y_hat - y) ** 2
def gradient(x, y):
return 2 * x * (w * x - y)
loss_list = []
epoch_list = []
for epoch in np.arange(100):
loss_val = 0
for x, y in zip(x_data, y_data):
loss_val = loss(x, y)
loss_val.backward()
print('\tgrad:', x, y, w.grad.item())
w.data -= 0.01 * w.grad.data
epoch_list.append(epoch)
loss_list.append(loss_val.item())
w.grad.data.zero_()
print('process:', epoch, loss_val.item())
print('Predict(after training)', 4, forward(4).item())
plt.plot(epoch_list, loss_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()
运行结果:
小结:我还是没有完全自己写明白
data
grad
item()
detach.numpy
这几个东西,还是没搞明白;
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)