# 使用numpy实现机器学习 import numpy as np # %matplotlib inline #在jupyter中可使用,但在pycharm中用不上 from matplotlib import pyplot as plt np.random.seed(100) x = np.linspace(-1, 1, 100).reshape(100, 1) y = 3*np.power(x, 2) + 2 + 0.2*np.random.rand(x.size).reshape(100, 1) # print(x) # print(y) # 画图 plt.scatter(x, y) plt.show() # 随机初始化参数 w1 = np.random.rand(1, 1) b1 = np.random.rand(1, 1) # 训练模型 lr =0.001 # 学习率 for i in range(800): # 前向传播 y_pred = np.power(x,2)*w1 + b1 # 定义损失函数 loss = 0.5 * (y_pred - y) ** 2 loss = loss.sum() #计算梯度 grad_w=np.sum((y_pred - y)*np.power(x,2)) grad_b=np.sum((y_pred - y)) #使用梯度下降法,是loss最小 w1 -= lr * grad_w b1 -= lr * grad_b # 可视化结果 plt.plot(x, y_pred, 'r-', label='predict') plt.scatter(x, y, color='blue', marker='o', label='true') # true data plt.xlim(-1, 1) plt.ylim(2, 6) plt.legend() plt.show() print(w1, b1)
运行结果
[[2.98927619]] [[2.09818307]]
结果图
使用Tensor及Autograd实现机器学习import torch from matplotlib import pyplot as plt torch.manual_seed(100) dtype = torch.float #生成x坐标数据,x为tenor,形状为100x1 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) #生成y坐标数据,y为tenor,形状为100x1,另加上一些噪音 y = 3*x.pow(2) + 2 + 0.2*torch.rand(x.size()) # 画图,把tensor数据转换为numpy数据 plt.scatter(x.numpy(), y.numpy()) plt.show() # 随机初始化参数,参数w,b为需要学习的,故需requires_grad=True w = torch.randn(1, 1, dtype=dtype, requires_grad=True) b = torch.zeros(1, 1, dtype=dtype, requires_grad=True) # 训练模型 lr = 0.001 # 学习率 for ii in range(800): # forward:计算loss y_pred = x.pow(2).mm(w) + b loss = 0.5 * (y_pred - y) ** 2 loss = loss.sum() # backward:自动计算梯度 loss.backward() # 手动更新参数,需要用torch.no_grad()更新参数 with torch.no_grad(): w -= lr * w.grad b -= lr * b.grad # 梯度清零 w.grad.zero_() b.grad.zero_() #可视化训练结果 plt.plot(x.numpy(), y_pred.detach().numpy(), 'r-', label='predict') # predict plt.scatter(x.numpy(), y.numpy(), color='blue', marker='o', label='true') # true data plt.xlim(-1, 1) plt.ylim(2, 6) plt.legend() plt.show() print(w, b)
运行结果
tensor([[2.9645]], requires_grad=True) tensor([[2.1146]], requires_grad=True)
结果图
总结使用numpy和Tensor,它们所实现的结果差不多。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)