- Preface
- Mathematical Principles
- PyTorch Code Implementation
- Gradient Descent Implementation
- Stochastic Gradient Descent Implementation
- Discussion Question
- References
Preface
This post introduces the mathematical principles of gradient descent and stochastic gradient descent (SGD) as used in deep learning, along with their PyTorch-course code implementation[1][2].
Mathematical Principles
Gradient descent: each iteration moves the weight in the direction in which the cost, computed over the whole training set, decreases fastest[3]; its main characteristic is efficiency, since one averaged gradient drives each update.
Stochastic gradient descent: each iteration randomly picks a single sample, computes the gradient of that sample's loss, and updates the weight with it, instead of accumulating and averaging over all samples.
Taking a linear model as the example, the two methods differ in the following formulas:
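Written out for the linear model ŷ = w·x with a mean-squared-error cost, which is exactly the setting of the scripts below (N samples (x_n, y_n), learning rate α = lr), the update rules are:

$$
\text{GD:}\qquad \mathrm{cost}(w)=\frac{1}{N}\sum_{n=1}^{N}\bigl(x_n w-y_n\bigr)^2,\qquad
\frac{\partial\,\mathrm{cost}}{\partial w}=\frac{1}{N}\sum_{n=1}^{N}2x_n\bigl(x_n w-y_n\bigr),\qquad
w\leftarrow w-\alpha\,\frac{\partial\,\mathrm{cost}}{\partial w}
$$

$$
\text{SGD:}\qquad \mathrm{loss}_n(w)=\bigl(x_n w-y_n\bigr)^2,\qquad
\frac{\partial\,\mathrm{loss}_n}{\partial w}=2x_n\bigl(x_n w-y_n\bigr),\qquad
w\leftarrow w-\alpha\,\frac{\partial\,\mathrm{loss}_n}{\partial w}
$$

As a quick sanity check against the log below: with x = [1, 2, 3], y = [2, 4, 6], w = 1.0 and α = 0.01, the batch gradient is (1/3)·[2·1·(1−2) + 2·2·(2−4) + 2·3·(3−6)] = −28/3 ≈ −9.3333, so the first update gives w ≈ 1.0933, which is exactly the value printed at Epoch 0 of the gradient-descent run.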
PyTorch Code Implementation

Gradient Descent Implementation
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/11/24 15:34
# @Author  : William Baker
# @FileName: gd.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346
# Gradient Descent: every update uses the gradient averaged over the whole training set

import matplotlib.pyplot as plt

# Training data
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Initial weight w
w = 1.0
# w = 5.0

# Learning rate
lr = 0.01


# Linear model y = w * x
def forward(x):
    return w * x


# Cost function: mean squared error over the whole data set
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)        # prediction y_hat
        cost += (y_pred - y) ** 2  # squared error
    return cost / len(xs)


# Gradient of the cost with respect to w, averaged over all samples
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        grad += (y_pred - y) * 2 * x
    return grad / len(xs)


epoch_list = []
cost_list = []

print('训练前的输入值x:{}, 训练前的预测值:{}\n'.format(4.0, forward(4.0)))  # prediction for x = 4 before training
print("***************************开始训练***************************")

# Training loop
for epoch in range(100):
    cost_val = cost(x_data, y_data)      # cost on the whole data set
    grad_val = gradient(x_data, y_data)  # averaged gradient
    w -= lr * grad_val                   # the core of gradient descent: w = w - lr * gradient
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, cost_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(cost_val)

print("***************************训练结束***************************\n")
print('训练后的输入值x:{}, 训练后的预测值:{}'.format(4.0, forward(4)))  # prediction for x = 4 after training

# Plot cost versus epoch
plt.plot(epoch_list, cost_list)
plt.ylabel('Cost')
plt.xlabel('Epoch')
plt.show()
The output is as follows:
训练前的输入值x:4.0, 训练前的预测值:4.0 ***************************开始训练*************************** Epoch:0, w=1.0933333333333333, loss=4.666666666666667, grad=-9.333333333333334 Epoch:1, w=1.1779555555555554, loss=3.8362074074074086, grad=-8.462222222222222 Epoch:2, w=1.2546797037037036, loss=3.1535329869958857, grad=-7.6724148148148155 Epoch:3, w=1.3242429313580246, loss=2.592344272332262, grad=-6.956322765432099 Epoch:4, w=1.3873135910979424, loss=2.1310222071581117, grad=-6.30706597399177 Epoch:5, w=1.4444976559288012, loss=1.7517949663820642, grad=-5.718406483085872 Epoch:6, w=1.4963445413754464, loss=1.440053319920117, grad=-5.184688544664522 Epoch:7, w=1.5433523841804047, loss=1.1837878313441108, grad=-4.700784280495834 Epoch:8, w=1.5859728283235668, loss=0.9731262101573632, grad=-4.262044414316223 Epoch:9, w=1.6246153643467005, loss=0.7999529948031382, grad=-3.864253602313377 Epoch:10, w=1.659651263674342, loss=0.6575969151946154, grad=-3.503589932764129 Epoch:11, w=1.6914171457314033, loss=0.5405738908195378, grad=-3.1765882057061425 Epoch:12, w=1.7202182121298057, loss=0.44437576375991855, grad=-2.8801066398402355 Epoch:13, w=1.7463311789976905, loss=0.365296627844598, grad=-2.6112966867884806 Epoch:14, w=1.7700069356245727, loss=0.3002900634939416, grad=-2.3675756626882225 Epoch:15, w=1.7914729549662791, loss=0.2468517784170642, grad=-2.1466019341706555 Epoch:16, w=1.8109354791694263, loss=0.2029231330489788, grad=-1.9462524203147282 Epoch:17, w=1.8285815011136133, loss=0.16681183417217407, grad=-1.764602194418688 Epoch:18, w=1.8445805610096762, loss=0.1371267415488235, grad=-1.5999059896062764 Epoch:19, w=1.8590863753154396, loss=0.11272427607497944, grad=-1.4505814305763567 Epoch:20, w=1.872238313619332, loss=0.09266436490145864, grad=-1.31519383038923 Epoch:21, w=1.8841627376815275, loss=0.07617422636521683, grad=-1.1924424062195684 Epoch:22, w=1.8949742154979183, loss=0.06261859959338009, grad=-1.081147781639076 Epoch:23, w=1.904776622051446, loss=0.051475271914629306, grad=-0.9802406553527626 Epoch:24, w=1.9136641373266443, loss=0.04231496130368814, grad=-0.888751527519838 Epoch:25, w=1.9217221511761575, loss=0.03478477885657844, grad=-0.8058013849513194 Epoch:26, w=1.9290280837330496, loss=0.02859463421027894, grad=-0.7305932556891969 Epoch:27, w=1.9356521292512983, loss=0.023506060193480772, grad=-0.6624045518248707 Epoch:28, w=1.9416579305211772, loss=0.01932302619282764, grad=-0.6005801269878838 Epoch:29, w=1.9471031903392007, loss=0.015884386331668398, grad=-0.5445259818023471 Epoch:30, w=1.952040225907542, loss=0.01305767153735723, grad=-0.49370355683412726 Epoch:31, w=1.9565164714895047, loss=0.010733986344664803, grad=-0.44762455819627417 Epoch:32, w=1.9605749341504843, loss=0.008823813841374291, grad=-0.40584626609795665 Epoch:33, w=1.9642546069631057, loss=0.007253567147113681, grad=-0.3679672812621462 Epoch:34, w=1.9675908436465492, loss=0.005962754575689583, grad=-0.33362366834434704 Epoch:35, w=1.970615698239538, loss=0.004901649272531298, grad=-0.3024854592988742 Epoch:36, w=1.9733582330705144, loss=0.004029373553099482, grad=-0.27425348309764513 Epoch:37, w=1.975844797983933, loss=0.0033123241439168096, grad=-0.24865649134186527 Epoch:38, w=1.9780992835054327, loss=0.0027228776607060357, grad=-0.22544855214995874 Epoch:39, w=1.980143350378259, loss=0.002238326453885249, grad=-0.20440668728262779 Epoch:40, w=1.9819966376762883, loss=0.001840003826269386, grad=-0.185328729802915 Epoch:41, w=1.983676951493168, loss=0.0015125649231412608, grad=-0.1680313816879758 Epoch:42, 
w=1.9852004360204722, loss=0.0012433955919298103, grad=-0.1523484527304313 Epoch:43, w=1.9865817286585614, loss=0.0010221264385926248, grad=-0.13812926380892523 Epoch:44, w=1.987834100650429, loss=0.0008402333603648631, grad=-0.12523719918675966 Epoch:45, w=1.9889695845897222, loss=0.0006907091659248264, grad=-0.11354839392932907 Epoch:46, w=1.9899990900280147, loss=0.0005677936325753796, grad=-0.10295054382926017 Epoch:47, w=1.9909325082920666, loss=0.0004667516012495216, grad=-0.09334182640519595 Epoch:48, w=1.9917788075181404, loss=0.000383690560742734, grad=-0.08462992260737945 Epoch:49, w=1.9925461188164473, loss=0.00031541069384432885, grad=-0.07673112983068957 Epoch:50, w=1.9932418143935788, loss=0.0002592816085930997, grad=-0.06956955771315876 Epoch:51, w=1.9938725783835114, loss=0.0002131410058905752, grad=-0.06307639899326374 Epoch:52, w=1.994444471067717, loss=0.00017521137977565514, grad=-0.0571892684205603 Epoch:53, w=1.9949629871013967, loss=0.0001440315413480261, grad=-0.05185160336797523 Epoch:54, w=1.9954331083052663, loss=0.0001184003283899171, grad=-0.0470121203869646 Epoch:55, w=1.9958593515301082, loss=9.733033217332803e-05, grad=-0.042624322484180986 Epoch:56, w=1.9962458120539648, loss=8.000985883901657e-05, grad=-0.0386460523856574 Epoch:57, w=1.9965962029289281, loss=6.57716599593935e-05, grad=-0.035039087496328225 Epoch:58, w=1.9969138906555615, loss=5.406722767150764e-05, grad=-0.03176877266333733 Epoch:59, w=1.997201927527709, loss=4.444566413387458e-05, grad=-0.02880368721475879 Epoch:60, w=1.9974630809584561, loss=3.65363112808981e-05, grad=-0.026115343074715636 Epoch:61, w=1.9976998600690001, loss=3.0034471708953996e-05, grad=-0.02367791105440838 Epoch:62, w=1.9979145397958935, loss=2.4689670610172655e-05, grad=-0.02146797268933165 Epoch:63, w=1.9981091827482769, loss=2.0296006560253656e-05, grad=-0.01946429523832638 Epoch:64, w=1.9982856590251044, loss=1.6684219437262796e-05, grad=-0.01764762768274834 Epoch:65, w=1.9984456641827613, loss=1.3715169898293847e-05, grad=-0.016000515765691798 Epoch:66, w=1.9985907355257035, loss=1.1274479219506377e-05, grad=-0.014507134294228674 Epoch:67, w=1.9987222668766378, loss=9.268123006398985e-06, grad=-0.013153135093433596 Epoch:68, w=1.9988415219681517, loss=7.61880902783969e-06, grad=-0.011925509151381094 Epoch:69, w=1.9989496465844576, loss=6.262999634617916e-06, grad=-0.010812461630584766 Epoch:70, w=1.9990476795699081, loss=5.1484640551938914e-06, grad=-0.009803298545062233 Epoch:71, w=1.9991365628100501, loss=4.232266273994499e-06, grad=-0.008888324014190227 Epoch:72, w=1.999217150281112, loss=3.479110977946351e-06, grad=-0.008058747106198657 Epoch:73, w=1.999290216254875, loss=2.859983851026929e-06, grad=-0.00730659737628736 Epoch:74, w=1.9993564627377531, loss=2.3510338359374262e-06, grad=-0.006624648287833749 Epoch:75, w=1.9994165262155628, loss=1.932654303533636e-06, grad=-0.00600634778096983 Epoch:76, w=1.999470983768777, loss=1.5887277332523938e-06, grad=-0.005445755321414225 Epoch:77, w=1.9995203586170245, loss=1.3060048068548734e-06, grad=-0.004937484824748761 Epoch:78, w=1.9995651251461022, loss=1.0735939958924364e-06, grad=-0.004476652907771476 Epoch:79, w=1.9996057134657994, loss=8.825419799121559e-07, grad=-0.004058831969712499 Epoch:80, w=1.9996425135423248, loss=7.254887315754342e-07, grad=-0.003680007652538434 Epoch:81, w=1.999675878945041, loss=5.963839812987369e-07, grad=-0.003336540271635139 Epoch:82, w=1.999706130243504, loss=4.902541385825727e-07, grad=-0.0030251298462834106 Epoch:83, 
w=1.9997335580874436, loss=4.0301069098738336e-07, grad=-0.002742784393962546 Epoch:84, w=1.9997584259992822, loss=3.312926995781724e-07, grad=-0.002486791183860415 Epoch:85, w=1.9997809729060159, loss=2.723373231729343e-07, grad=-0.002254690673365515 Epoch:86, w=1.9998014154347876, loss=2.2387338352920307e-07, grad=-0.0020442528771848303 Epoch:87, w=1.9998199499942075, loss=1.8403387118941732e-07, grad=-0.0018534559419821999 Epoch:88, w=1.9998367546614149, loss=1.5128402140063082e-07, grad=-0.0016804667207292272 Epoch:89, w=1.9998519908930161, loss=1.2436218932547864e-07, grad=-0.0015236231601270707 Epoch:90, w=1.9998658050763347, loss=1.0223124683409346e-07, grad=-0.00138141833184946 Epoch:91, w=1.9998783299358769, loss=8.403862850836479e-08, grad=-0.0012524859542084599 Epoch:92, w=1.9998896858085284, loss=6.908348768398496e-08, grad=-0.0011355872651486187 Epoch:93, w=1.9998999817997325, loss=5.678969725349543e-08, grad=-0.0010295991204016808 Epoch:94, w=1.9999093168317574, loss=4.66836551287917e-08, grad=-0.0009335032024962627 Epoch:95, w=1.9999177805941268, loss=3.8376039345125727e-08, grad=-0.000846376236931512 Epoch:96, w=1.9999254544053418, loss=3.154680994333735e-08, grad=-0.0007673811214832978 Epoch:97, w=1.9999324119941766, loss=2.593287985380858e-08, grad=-0.0006957588834774301 Epoch:98, w=1.9999387202080534, loss=2.131797981222471e-08, grad=-0.000630821387685554 Epoch:99, w=1.9999444396553017, loss=1.752432687141379e-08, grad=-0.0005719447248348312 ***************************训练结束*************************** 训练后的输入值x:4.0, 训练后的预测值:7.999777758621207

Stochastic Gradient Descent Implementation
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/11/24 16:52
# @Author  : William Baker
# @FileName: SGD.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346
# SGD: stochastic gradient descent

import matplotlib.pyplot as plt

# Training data
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Initial weight w
w = 1.0

# Learning rate
lr = 0.01


# Linear model y = w * x
def forward(x):
    return w * x


"""
# Per-sample loss
def loss(x, y):
    y_pred = forward(x)  # prediction y_hat
    return (y_pred - y) ** 2
"""


# Loss function; note that the return statement sits after the zip loop,
# so it is evaluated with whatever (x, y) pair the loop left behind
def loss(xs, ys):
    for x, y in zip(xs, ys):
        y_pred = forward(x)  # prediction y_hat
    return (y_pred - y) ** 2


"""
# Per-sample gradient
def gradient(x, y):
    y_pred = forward(x)
    return 2 * x * (y_pred - y)
"""


# Gradient; same pattern, the return is evaluated after the zip loop finishes
def gradient(xs, ys):
    for x, y in zip(xs, ys):
        y_pred = forward(x)
    return 2 * x * (y_pred - y)


epoch_list = []
cost_list = []

print('训练前的输入值x:{}, 训练前的预测值:{}\n'.format(4.0, forward(4.0)))  # prediction for x = 4 before training
print("***************************开始训练***************************")

"""
# Training loop with zip in the loop body: one weight update per sample
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        loss_val = loss(x, y)
        grad_val = gradient(x, y)
        w -= lr * grad_val  # the core update: w = w - lr * gradient
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)
"""

# Training loop with zip inside loss()/gradient(): one weight update per epoch
for epoch in range(100):
    loss_val = loss(x_data, y_data)      # loss value
    grad_val = gradient(x_data, y_data)  # gradient value
    w -= lr * grad_val                   # the core update: w = w - lr * gradient
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)

print("***************************训练结束***************************\n")
print('训练后的输入值x:{}, 训练后的预测值:{}'.format(4.0, forward(4)))  # prediction for x = 4 after training

# Plot loss versus epoch
plt.plot(epoch_list, cost_list)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()
The output is as follows:
训练前的输入值x:4.0, 训练前的预测值:4.0 ***************************开始训练*************************** Epoch:0, w=1.18, loss=9.0, grad=-18.0 Epoch:1, w=1.3276, loss=6.0516, grad=-14.76 Epoch:2, w=1.448632, loss=4.069095840000001, grad=-12.103200000000001 Epoch:3, w=1.54787824, loss=2.7360600428160007, grad=-9.924624000000001 Epoch:4, w=1.6292601568, loss=1.8397267727894793, grad=-8.138191680000002 Epoch:5, w=1.695993328576, loss=1.237032282023645, grad=-6.6733171775999995 Epoch:6, w=1.75071452943232, loss=0.8317805064326984, grad=-5.472120085631998 Epoch:7, w=1.7955859141345025, loss=0.5592892125253471, grad=-4.487138470218241 Epoch:8, w=1.832380449590292, loss=0.3760660665020425, grad=-3.679453545578953 Epoch:9, w=1.8625519686640395, loss=0.2528668231159735, grad=-3.0171519073747426 Epoch:10, w=1.8872926143045123, loss=0.17002765186318064, grad=-2.474064564047289 Epoch:11, w=1.9075799437297, loss=0.1143265931128029, grad=-2.0287329425187792 Epoch:12, w=1.924215553858354, loss=0.07687320120904852, grad=-1.6635610128653973 Epoch:13, w=1.9378567541638503, loss=0.05168954049296429, grad=-1.3641200305496266 Epoch:14, w=1.9490425384143573, loss=0.03475604702746934, grad=-1.1185784250506963 Epoch:15, w=1.9582148814997729, loss=0.02336996602127023, grad=-0.917234308541568 Epoch:16, w=1.9657362028298138, loss=0.01571396515270217, grad=-0.7521321330040873 Epoch:17, w=1.9719036863204473, loss=0.010566070168677008, grad=-0.6167483490633536 Epoch:18, w=1.9769610227827668, loss=0.007104625581418294, grad=-0.5057336462319455 Epoch:19, w=1.9811080386818687, loss=0.004777150240945764, grad=-0.4147015899101998 Epoch:20, w=1.9845085917191323, loss=0.0032121558220119784, grad=-0.3400553037263663 Epoch:21, w=1.9872970452096885, loss=0.0021598535747208276, grad=-0.27884534905561864 Epoch:22, w=1.9895835770719446, loss=0.0014522855436422575, grad=-0.22865318622560515 Epoch:23, w=1.9914585331989945, loss=0.0009765167995450728, grad=-0.18749561270499804 Epoch:24, w=1.9929959972231754, loss=0.0006566098960140951, grad=-0.153746402418097 Epoch:25, w=1.994256717723004, loss=0.00044150449407990444, grad=-0.12607204998284338 Epoch:26, w=1.9952905085328632, loss=0.0002968676218193051, grad=-0.10337908098592763 Epoch:27, w=1.9961382169969477, loss=0.00019961378891129975, grad=-0.08477084640846044 Epoch:28, w=1.9968333379374972, loss=0.00013422031166397195, grad=-0.06951209405494119 Epoch:29, w=1.9974033371087476, loss=9.024973756285169e-05, grad=-0.056999917125050814 Epoch:30, w=1.997870736429173, loss=6.068392353726093e-05, grad=-0.046739932042541454 Epoch:31, w=1.9982540038719219, loss=4.080387018645243e-05, grad=-0.03832674427488314 Epoch:32, w=1.9985682831749758, loss=2.7436522313372105e-05, grad=-0.03142793030540503 Epoch:33, w=1.9988259922034801, loss=1.844831760351415e-05, grad=-0.02577090285043404 Epoch:34, w=1.9990373136068538, loss=1.2404648756606792e-05, grad=-0.021132140337359218 Epoch:35, w=1.9992105971576202, loss=8.340885823940048e-06, grad=-0.017328355076632107 Epoch:36, w=1.9993526896692486, loss=5.608411628015101e-06, grad=-0.014209251162835557 Epoch:37, w=1.999469205528784, loss=3.7710959786780433e-06, grad=-0.011651585953526222 Epoch:38, w=1.9995647485336028, loss=2.535684936061193e-06, grad=-0.009554300481887879 Epoch:39, w=1.9996430937975542, loss=1.7049945510078245e-06, grad=-0.0078345263951487 Epoch:40, w=1.9997073369139944, loss=1.1464383360988404e-06, grad=-0.006424311644025238 Epoch:41, w=1.9997600162694753, loss=7.708651371929851e-07, grad=-0.005267935548101121 Epoch:42, w=1.9998032133409698, 
loss=5.183297182483074e-07, grad=-0.004319707149441854 Epoch:43, w=1.9998386349395951, loss=3.4852490255049745e-07, grad=-0.0035421598625440254 Epoch:44, w=1.999867680650468, loss=2.343481444747825e-07, grad=-0.002904571087285035 Epoch:45, w=1.9998914981333837, loss=1.5757569234467452e-07, grad=-0.0023817482915724497 Epoch:46, w=1.9999110284693746, loss=1.0595389553262853e-07, grad=-0.0019530335990900483 Epoch:47, w=1.9999270433448872, loss=7.124339935627219e-08, grad=-0.0016014875512553317 Epoch:48, w=1.9999401755428075, loss=4.790406172717297e-08, grad=-0.0013132197920295852 Epoch:49, w=1.9999509439451022, loss=3.221069110544675e-08, grad=-0.0010768402294658586 Epoch:50, w=1.9999597740349837, loss=2.1658468699093255e-08, grad=-0.0008830089881577408 Epoch:51, w=1.9999670147086868, loss=1.456315435353612e-08, grad=-0.0007240673702959555 Epoch:52, w=1.9999729520611231, loss=9.792264987127843e-09, grad=-0.0005937352436369281 Epoch:53, w=1.9999778206901209, loss=6.584318977344762e-09, grad=-0.00048686289978228103 Epoch:54, w=1.999981812965899, loss=4.427296080402076e-09, grad=-0.00039922757782306917 Epoch:55, w=1.9999850866320372, loss=2.976913884501124e-09, grad=-0.00032736661381704835 Epoch:56, w=1.9999877710382705, loss=2.001676895938556e-09, grad=-0.00026844062332997964 Epoch:57, w=1.9999899722513819, loss=1.3459275448290849e-09, grad=-0.0002201213111305833 Epoch:58, w=1.999991777246133, loss=9.050016811366642e-10, grad=-0.00018049947512643882 Epoch:59, w=1.999993257341829, loss=6.08523130391911e-10, grad=-0.00014800956960314693 Epoch:60, w=1.9999944710203, loss=4.091709529057039e-10, grad=-0.0001213678470790569 Epoch:61, w=1.999995466236646, loss=2.7512654872200957e-10, grad=-9.952163460269503e-05 Epoch:62, w=1.9999962823140496, loss=1.8499509135681353e-10, grad=-8.160774037335727e-05 Epoch:63, w=1.9999969514975207, loss=1.2439069943862355e-10, grad=-6.691834710892408e-05 Epoch:64, w=1.9999975002279669, loss=8.364030629083358e-11, grad=-5.4873044625480816e-05 Epoch:65, w=1.999997950186933, loss=5.623974196274511e-11, grad=-4.4995896598010177e-05 Epoch:66, w=1.999998319153285, loss=3.781560249181731e-11, grad=-3.689663520844988e-05 Epoch:67, w=1.9999986217056938, loss=2.5427211109406962e-11, grad=-3.0255240867305133e-05 Epoch:68, w=1.9999988697986688, loss=1.709725675158115e-11, grad=-2.4809297512362605e-05 Epoch:69, w=1.9999990732349084, loss=1.1496195438197204e-11, grad=-2.0343623958751778e-05 Epoch:70, w=1.999999240052625, loss=7.730041815607078e-12, grad=-1.66817716493739e-05 Epoch:71, w=1.9999993768431525, loss=5.197680114303315e-12, grad=-1.3679052749182574e-05 Epoch:72, w=1.999999489011385, loss=3.4949201104515554e-12, grad=-1.1216823256887665e-05 Epoch:73, w=1.9999995809893356, loss=2.349984281723007e-12, grad=-9.197795069582071e-06 Epoch:74, w=1.9999996564112552, loss=1.5801294318344075e-12, grad=-7.5421919589757636e-06 Epoch:75, w=1.9999997182572293, loss=1.0624790294527732e-12, grad=-6.184597404867986e-06 Epoch:76, w=1.9999997689709281, loss=7.144108997944157e-13, grad=-5.071369873377307e-06 Epoch:77, w=1.9999998105561612, loss=4.803698890710119e-13, grad=-4.158523296382555e-06 Epoch:78, w=1.999999844656052, loss=3.230007131690541e-13, grad=-3.409989101754718e-06 Epoch:79, w=1.9999998726179626, loss=2.1718567983289397e-13, grad=-2.796191065357334e-06 Epoch:80, w=1.9999998955467293, loss=1.4603565098387233e-13, grad=-2.2928766725272e-06 Epoch:81, w=1.999999914348318, loss=9.81943716992902e-14, grad=-1.880158871259141e-06 Epoch:82, w=1.9999999297656208, loss=6.60258955032161e-14, 
grad=-1.5417302741127514e-06 Epoch:83, w=1.999999942407809, loss=4.439581219624794e-14, grad=-1.2642188256251075e-06 Epoch:84, w=1.9999999527744035, loss=2.985174434787262e-14, grad=-1.0366594409561003e-06 Epoch:85, w=1.9999999612750108, loss=2.0072312657907758e-14, grad=-8.500607364680945e-07 Epoch:86, w=1.9999999682455087, loss=1.3496623055941361e-14, grad=-6.97049804543326e-07 Epoch:87, w=1.9999999739613172, loss=9.075129393581547e-15, grad=-5.715808413242485e-07 Epoch:88, w=1.99999997864828, loss=6.1021170708499814e-15, grad=-4.686962924438376e-07 Epoch:89, w=1.9999999824915897, loss=4.103063445617242e-15, grad=-3.843309563933417e-07 Epoch:90, w=1.9999999856431034, loss=2.7588998272437547e-15, grad=-3.151513823240748e-07 Epoch:91, w=1.999999988227345, loss=1.8550843478908236e-15, grad=-2.5842414075327724e-07 Epoch:92, w=1.999999990346423, loss=1.2473586828983886e-15, grad=-2.1190779264657067e-07 Epoch:93, w=1.9999999920840668, loss=8.387239228207162e-16, grad=-1.737643842147918e-07 Epoch:94, w=1.9999999935089348, loss=5.639580129513639e-16, grad=-1.4248680102468825e-07 Epoch:95, w=1.9999999946773266, loss=3.79205343002729e-16, grad=-1.1683917300331359e-07 Epoch:96, w=1.9999999956354078, loss=2.54977656183392e-16, grad=-9.580811877185624e-08 Epoch:97, w=1.9999999964210344, loss=1.7144698904287566e-16, grad=-7.856266037720161e-08 Epoch:98, w=1.9999999970652482, loss=1.1528095085501517e-16, grad=-6.44213802303284e-08 Epoch:99, w=1.9999999975935034, loss=7.751490791422244e-17, grad=-5.282553061647377e-08 ***************************训练结束*************************** 训练后的输入值x:4.0, 训练后的预测值:7.999999990374014

Discussion Question
The comparison below shows that writing the zip call inside the loss/gradient functions versus writing it in the training loop where the data is iterated produces different loss values at each epoch. Does anyone know why? If you do, feel free to post the answer in the comments!
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/11/24 16:52
# @Author  : William Baker
# @FileName: SGD.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346
# SGD: stochastic gradient descent, with zip written in the training loop for comparison

import matplotlib.pyplot as plt

# Training data
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Initial weight w
w = 1.0

# Learning rate
lr = 0.01


# Linear model y = w * x
def forward(x):
    return w * x


# Per-sample loss
def loss(x, y):
    y_pred = forward(x)  # prediction y_hat
    return (y_pred - y) ** 2


# Per-sample gradient
def gradient(x, y):
    y_pred = forward(x)
    return 2 * x * (y_pred - y)


epoch_list = []
cost_list = []

print('训练前的输入值x:{}, 训练前的预测值:{}\n'.format(4.0, forward(4.0)))  # prediction for x = 4 before training
print("***************************开始训练***************************")

# Training loop with zip in the loop body: one weight update per sample
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        loss_val = loss(x, y)
        grad_val = gradient(x, y)
        w -= lr * grad_val  # the core update: w = w - lr * gradient
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)

"""
# Training loop with zip inside loss()/gradient(): one weight update per epoch
for epoch in range(100):
    loss_val = loss(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= lr * grad_val  # the core update: w = w - lr * gradient
    print('Epoch:{}, w={}, loss={}, grad={}'.format(epoch, w, loss_val, grad_val))
    epoch_list.append(epoch)
    cost_list.append(loss_val)
"""

print("***************************训练结束***************************\n")
print('训练后的输入值x:{}, 训练后的预测值:{}'.format(4.0, forward(4)))  # prediction for x = 4 after training

# Plot loss versus epoch
plt.plot(epoch_list, cost_list)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()
训练前的输入值x:4.0, 训练前的预测值:4.0 ***************************开始训练*************************** Epoch:0, w=1.260688, loss=7.315943039999998, grad=-16.2288 Epoch:1, w=1.453417766656, loss=3.9987644858206908, grad=-11.998146585599997 Epoch:2, w=1.5959051959019805, loss=2.1856536232765476, grad=-8.87037374849311 Epoch:3, w=1.701247862192685, loss=1.1946394387269013, grad=-6.557973756745939 Epoch:4, w=1.7791289594933983, loss=0.6529686924601721, grad=-4.848388694047353 Epoch:5, w=1.836707389300983, loss=0.3569010862285927, grad=-3.584471942173538 Epoch:6, w=1.8792758133988885, loss=0.195075792793724, grad=-2.650043120512205 Epoch:7, w=1.910747160155559, loss=0.10662496249654511, grad=-1.9592086795121197 Epoch:8, w=1.9340143044689266, loss=0.05827931013158195, grad=-1.4484664872674653 Epoch:9, w=1.9512159834655312, loss=0.03185443548946761, grad=-1.0708686556346834 Epoch:10, w=1.9639333911678687, loss=0.017411068491745587, grad=-0.7917060475345892 Epoch:11, w=1.9733355232910992, loss=0.009516580701123755, grad=-0.5853177814148953 Epoch:12, w=1.9802866323953892, loss=0.005201593933418656, grad=-0.4327324596134101 Epoch:13, w=1.9854256707695, loss=0.0028430988290765965, grad=-0.3199243001817109 Epoch:14, w=1.9892250235079405, loss=0.0015539873076143675, grad=-0.2365238742159388 Epoch:15, w=1.9920339305797026, loss=0.000849381853184108, grad=-0.17486493849433593 Epoch:16, w=1.994110589284741, loss=0.00046425703027525234, grad=-0.12927974740812687 Epoch:17, w=1.9956458879852805, loss=0.0002537546444534916, grad=-0.09557806861579543 Epoch:18, w=1.9967809527381737, loss=0.00013869778028681822, grad=-0.07066201306448505 Epoch:19, w=1.9976201197307648, loss=7.580974250901852e-05, grad=-0.052241274202728505 Epoch:20, w=1.998240525958391, loss=4.143625836981207e-05, grad=-0.03862260091336722 Epoch:21, w=1.99869919972735, loss=2.2648322641186416e-05, grad=-0.028554152326460525 Epoch:22, w=1.9990383027488265, loss=1.2379170770717257e-05, grad=-0.021110427464781978 Epoch:23, w=1.9992890056818404, loss=6.766234806804613e-06, grad=-0.01560719234984198 Epoch:24, w=1.999474353368653, loss=3.6983037320320918e-06, grad=-0.011538584590544687 Epoch:25, w=1.9996113831376856, loss=2.021427113440456e-06, grad=-0.008530614050808794 Epoch:26, w=1.9997126908902887, loss=1.104876146204831e-06, grad=-0.006306785335127074 Epoch:27, w=1.9997875889274812, loss=6.039056715601388e-07, grad=-0.00466268207967957 Epoch:28, w=1.9998429619451539, loss=3.300841106907982e-07, grad=-0.0034471768136938863 Epoch:29, w=1.9998838998815958, loss=1.8041811041234527e-07, grad=-0.0025485391844828342 Epoch:30, w=1.9999141657892625, loss=9.861333372463779e-08, grad=-0.0018841656015560204 Epoch:31, w=1.9999365417379913, loss=5.390029618456292e-08, grad=-0.0013929862392156878 Epoch:32, w=1.9999530845453979, loss=2.946094426616231e-08, grad=-0.0010298514424817995 Epoch:33, w=1.9999653148414271, loss=1.6102828713572706e-08, grad=-0.0007613815296476645 Epoch:34, w=1.999974356846045, loss=8.801520081617991e-09, grad=-0.0005628985014531906 Epoch:35, w=1.9999810417085633, loss=4.810754502894822e-09, grad=-0.0004161576169003922 Epoch:36, w=1.9999859839076413, loss=2.6294729403166827e-09, grad=-0.0003076703200690645 Epoch:37, w=1.9999896377347262, loss=1.43722319226415e-09, grad=-0.00022746435967313516 Epoch:38, w=1.999992339052936, loss=7.855606621992112e-10, grad=-0.00016816713067413502 Epoch:39, w=1.9999943361699042, loss=4.293735011907528e-10, grad=-0.00012432797771566584 Epoch:40, w=1.9999958126624442, loss=2.3468792720119317e-10, grad=-9.191716585732479e-05 
Epoch:41, w=1.999996904251097, loss=1.2827625139303397e-10, grad=-6.795546372551087e-05 Epoch:42, w=1.999997711275687, loss=7.011351996364471e-11, grad=-5.0240289795056015e-05 Epoch:43, w=1.9999983079186507, loss=3.832280433642867e-11, grad=-3.714324913239864e-05 Epoch:44, w=1.9999987490239537, loss=2.0946563973304985e-11, grad=-2.7460449796734565e-05 Epoch:45, w=1.9999990751383971, loss=1.1449019716984442e-11, grad=-2.0301840059744336e-05 Epoch:46, w=1.9999993162387186, loss=6.257830771044664e-12, grad=-1.5009393983689279e-05 Epoch:47, w=1.9999994944870796, loss=3.4204191158124504e-12, grad=-1.109662508014253e-05 Epoch:48, w=1.9999996262682318, loss=1.8695403191196464e-12, grad=-8.20386808086937e-06 Epoch:49, w=1.999999723695619, loss=1.0218575233353146e-12, grad=-6.065218119744031e-06 Epoch:50, w=1.9999997957248556, loss=5.585291664185541e-13, grad=-4.484088535150477e-06 Epoch:51, w=1.9999998489769344, loss=3.0528211874783223e-13, grad=-3.3151404608133817e-06 Epoch:52, w=1.9999998883468353, loss=1.6686178282138566e-13, grad=-2.4509231284497446e-06 Epoch:53, w=1.9999999174534755, loss=9.120368570648034e-14, grad=-1.811996877876254e-06 Epoch:54, w=1.999999938972364, loss=4.9850314593866976e-14, grad=-1.3396310407642886e-06 Epoch:55, w=1.9999999548815364, loss=2.7247296013817913e-14, grad=-9.904052991061008e-07 Epoch:56, w=1.9999999666433785, loss=1.4892887826055098e-14, grad=-7.322185204827747e-07 Epoch:57, w=1.9999999753390494, loss=8.140187918760348e-15, grad=-5.413379398078177e-07 Epoch:58, w=1.9999999817678633, loss=4.449282094197275e-15, grad=-4.002176350326181e-07 Epoch:59, w=1.9999999865207625, loss=2.4318985397157373e-15, grad=-2.9588569994132286e-07 Epoch:60, w=1.999999990034638, loss=1.3292325355918982e-15, grad=-2.1875184863517916e-07 Epoch:61, w=1.9999999926324883, loss=7.265349176868077e-16, grad=-1.617258700292723e-07 Epoch:62, w=1.99999999455311, loss=3.971110830662586e-16, grad=-1.195658771990793e-07 Epoch:63, w=1.9999999959730488, loss=2.1705387408049341e-16, grad=-8.839649012770678e-08 Epoch:64, w=1.9999999970228268, loss=1.186377771034419e-16, grad=-6.53525820126788e-08 Epoch:65, w=1.9999999977989402, loss=6.484530240933061e-17, grad=-4.8315948575350376e-08 Epoch:66, w=1.9999999983727301, loss=3.544328347681514e-17, grad=-3.5720557178819945e-08 Epoch:67, w=1.9999999987969397, loss=1.937267496512019e-17, grad=-2.6408640607655798e-08 Epoch:68, w=1.999999999110563, loss=1.0588762836876607e-17, grad=-1.9524227568012975e-08 Epoch:69, w=1.9999999993424284, loss=5.78763006728202e-18, grad=-1.4434496264925656e-08 Epoch:70, w=1.9999999995138495, loss=3.1634161455080883e-18, grad=-1.067159693945996e-08 Epoch:71, w=1.9999999996405833, loss=1.7290652800585402e-18, grad=-7.88963561149103e-09 Epoch:72, w=1.999999999734279, loss=9.45076322860815e-19, grad=-5.832902161273523e-09 Epoch:73, w=1.9999999998035491, loss=5.165625480570949e-19, grad=-4.31233715403323e-09 Epoch:74, w=1.9999999998547615, loss=2.8234328489219424e-19, grad=-3.188159070077745e-09 Epoch:75, w=1.9999999998926234, loss=1.5432429879714383e-19, grad=-2.3570478902001923e-09 Epoch:76, w=1.9999999999206153, loss=8.435055999638128e-20, grad=-1.7425900722400911e-09 Epoch:77, w=1.9999999999413098, loss=4.610497285725064e-20, grad=-1.2883241140571045e-09 Epoch:78, w=1.9999999999566096, loss=2.520026245264157e-20, grad=-9.524754318590567e-10 Epoch:79, w=1.9999999999679208, loss=1.377407569393045e-20, grad=-7.041780492045291e-10 Epoch:80, w=1.9999999999762834, loss=7.528673013117128e-21, grad=-5.206075570640678e-10 Epoch:81, 
w=1.999999999982466, loss=4.115053962078213e-21, grad=-3.8489211817704927e-10 Epoch:82, w=1.9999999999870368, loss=2.2492314589577438e-21, grad=-2.845563784603655e-10 Epoch:83, w=1.999999999990416, loss=1.229387284413716e-21, grad=-2.1037571684701106e-10 Epoch:84, w=1.9999999999929146, loss=6.719695441682722e-22, grad=-1.5553425214420713e-10 Epoch:85, w=1.9999999999947617, loss=3.6730159234427135e-22, grad=-1.1499068364173581e-10 Epoch:86, w=1.9999999999961273, loss=2.0073851892168518e-22, grad=-8.500933290633839e-11 Epoch:87, w=1.999999999997137, loss=1.0972931813778698e-22, grad=-6.285105769165966e-11 Epoch:88, w=1.9999999999978835, loss=5.996996411023123e-23, grad=-4.646416584819235e-11 Epoch:89, w=1.9999999999984353, loss=3.2777893208522223e-23, grad=-3.4351188560322043e-11 Epoch:90, w=1.9999999999988431, loss=1.791878298003441e-23, grad=-2.539835008974478e-11 Epoch:91, w=1.9999999999991447, loss=9.796529104915932e-24, grad=-1.8779644506139448e-11 Epoch:92, w=1.9999999999993676, loss=5.353229824352417e-24, grad=-1.3882228699912957e-11 Epoch:93, w=1.9999999999995324, loss=2.926260595255618e-24, grad=-1.0263789818054647e-11 Epoch:94, w=1.9999999999996543, loss=1.5996332109454424e-24, grad=-7.58859641791787e-12 Epoch:95, w=1.9999999999997444, loss=8.746960714572049e-25, grad=-5.611511255665391e-12 Epoch:96, w=1.999999999999811, loss=4.774848841557949e-25, grad=-4.1460168631601846e-12 Epoch:97, w=1.9999999999998603, loss=2.6081713678869703e-25, grad=-3.064215547965432e-12 Epoch:98, w=1.9999999999998967, loss=1.4248800100554526e-25, grad=-2.2648549702353193e-12 Epoch:99, w=1.9999999999999236, loss=7.82747233205549e-26, grad=-1.6786572132332367e-12 ***************************训练结束*************************** 训练后的输入值x:4.0, 训练后的预测值:7.9999999999996945
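Before wrapping up, one more note: the scripts above implement the update rules by hand in plain Python. For readers who want to see the same toy model driven by PyTorch itself, here is a minimal sketch of my own (not code from the course) using autograd and the built-in torch.optim.SGD optimizer; it assumes PyTorch is installed and uses torch.nn.MSELoss with its default mean reduction, with the same data, initial weight, and learning rate as above.

# Minimal PyTorch sketch of the same linear model y = w * x (illustrative; assumes torch is installed)
import torch

x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

w = torch.tensor([1.0], requires_grad=True)  # same initial weight as the scripts above

optimizer = torch.optim.SGD([w], lr=0.01)    # gradient-descent optimizer over the single parameter
criterion = torch.nn.MSELoss()               # mean squared error, like cost() above

for epoch in range(100):
    y_pred = x_data * w               # forward pass: y_hat = x * w
    loss = criterion(y_pred, y_data)  # MSE averaged over the whole batch
    optimizer.zero_grad()             # clear the gradient from the previous step
    loss.backward()                   # autograd fills in w.grad = d(loss)/dw
    optimizer.step()                  # update: w <- w - lr * w.grad
    # print('Epoch:{}, w={}, loss={}'.format(epoch, w.item(), loss.item()))

print('x = 4.0, prediction after training:', (4.0 * w).item())

Because MSELoss averages over the whole batch, this loop behaves like the batch gradient-descent script; feeding one (x, y) pair at a time through the same optimizer inside the epoch loop would give the per-sample SGD behaviour instead.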
That is about all for this post. If this article has helped you, I am glad; how far a person can go depends on who they travel with.
References
[1] 《PyTorch深度学习实践》完结合集 (the complete "PyTorch Deep Learning Practice" video course)
[2] PyTorch 深度学习实践 第3讲 (PyTorch Deep Learning Practice, Lecture 3)
[3] Pytorch深度学习实践 第三讲 梯度下降算法 (PyTorch Deep Learning Practice, Lecture 3: the gradient descent algorithm)