关于global

关于global,第1张

关于global 关于global_step参数
  1. global_step与learning_rate的关系
  2. global_step与estimator的关系

1 global_step与learning_rate的关系
learning_rate:在梯度下降中学习率大小很关键,过大容易造成震荡,过小容易陷入局部最优值以及降低收敛速度加长学习时间。起初学习率过大是加速到达最低点,解决学习率国小的问题;随着训练步数的增加,学习率呈指数递减,防止学习率过大而到达不了最低点,使得最后趋于平稳,到达一个稳定的学习状态
learning_rate衰减方式:
lr_init - 初始学习率
decay_rate - 学习衰减率
global_step - 全局训练步数(每训练一个batch就+1)
decay_step - 衰减步数(训练100 或 10000步之后开始改变学习率)

		阶梯式l  r_new = lr_init * decay_rate^(global_step/decay_step)
		 				当且仅当global_step/decay_step 为整数时学习率才会改变
		 				 tfAPI:
		指数形式 lr_new = lr_init * decay_rate^(global_step/decay_step)
		               每训练一个batch学习率都会更新
 global_step = tf.Variable(0)
 
# 通过exponential_decay函数生成学习率
#staircase = True为阶梯式更新lr,staircase = False为指数形式更新lr
learning_rate = tf.train.exponential_decay(0.1, global_step, 100, 0.96, staircase = True)
 
# 使用指数衰减的学习率。在minimize函数中传入global_step将自动更新
# global_ste参数,从而使得学习率也得到相应更新
learning_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(..my loss.., global_step = global_step)
 

2 global_step与estimator的关系
estimator.train日志显示 global_step/sec,没有得到理想的loss和每step时间

'''
                every_n_iter : 指定打印的step数
            '''
            logging_hook = tf.train.LoggingTensorHook({"loss":total_loss,"step":global_steps},every_n_iter=1)
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=[logging_hook],  #add
                scaffold_fn=scaffold_fn)  # 钩子,这里用来将BERT中的参数作为我们模型的初始值

个人理解,仅供参考。
参考博客:global_step用法与理解

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5480729.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-12
下一篇 2022-12-12

发表评论

登录后才能评论

评论列表(0条)

保存