Tensorflow中的低阶API的使用

Tensorflow中的低阶API的使用,第1张

9.优化器,深度学习优化算法大概经历了SGD->SGDM->NAG->Adagrad->Adadelta(RMSprop)->Adam->Nadam这样的发展历程,在keras.optimizers子模块中,他们基本上都有对应的类的实现。


#优化器
import tensorflow as tf
optimizer=tf.keras.optimizers.SGD(learning_rate=5e-4)#基础的随机下降梯度算法
optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4)#使用最多的Adam算法

10.梯度记录器

梯度记录器实现模型的反向传播,主要涉及如下两个方法。


tape.gradient:tape.gradient:求解梯度,需要传入两个参数,即损失值和参数,分别是因变量和自变量。


optimizer.apply_gradients:更新模型参数,依然传入两个参数,即梯度和变量。


变量将会按照w_new=w-lr*grads的方式进行参数的更新。


#使用低阶API进行梯度更新
with tf.GradientTape()as tape:
    loss=tf.losses.MeanSquaredError()(model(x),y)
grads=tape.gradient(loss,varibles)
optimizer.apply_gradients(grads_and_vars=zip(grads,variables))

11.低阶API实现线性回归

用数据拟合y=2x的线性函数,使用100条数据,weight和bias直接初始化为1,损失函数为MSE,学习率为0.1,训练一个epoch。


import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
input_x=np.float32(np.linspace(-1,1,100))#在区间[-1,1)内产生100个数的等差数列,作为输入
input_y=2*input_x+np.random.randn(*input_x.shape)*0.3 #y=2x+随机噪声
weight=tf.Variable(1.,dtype=tf.float32,name='weight')
bias=tf.Variable(1.,dtype=tf.float32,name='bias')
def model(x):#定义了线性模型y=weight*x+bias
    pred=tf.multiply(x,weight)+bias
    return pred
step=0
opt=tf.optimizers.Adam(1e-1)#选择优化器,是一种梯度下降的方法
for x,y in zip(input_x,input_y):
    x=np.reshape(x,[1])
    y=np.reshape(y,[1])
    step=step+1
    with tf.GradientTape() as tape:
        loss=tf.losses.MeanSquaredError()(model(x),y)#连续数据的预测,损失函数用MSE
    grads=tape.gradient(loss,[weight,bias])#计算梯度
    opt.apply_gradients(zip(grads,[weight,bias]))#更新参数weight和bias
    print("step:",step,"Traing Loss",loss.numpy())
    #用matplotlib可视化原始数据和预测的模型
    plt.plot(input_x,input_y,'ro',label='original data')
    plt.plot(input_x,model(input_x),label='predicted value')
    plt.plot(input_x,2*input_x,label='y=2x')
    plt.legend()
    plt.show()
    print(weight)
    print(bias)

运行结果: 

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/569969.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-09
下一篇 2022-04-09

发表评论

登录后才能评论

评论列表(0条)

保存