滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。
1、滑动平均求解对象初始化
ema = tf.train.ExponentialMovingAverage(decay,num_updates)
参数decay
`shadow_variable = decay * shadow_variable + (1 - decay) * variable`
参数num_updates
`min(decay, (1 + num_updates) / (10 + num_updates))`
2、添加/更新变量
添加目标变量,为之维护影子变量
注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量
ema.apply([var0, var1])
3、获取影子变量值
这一步不需要定义图中,从影子变量集合中提取目标值
sess.run(ema.average([var0, var1]))
4、保存&载入影子变量
我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。
保存影子变量
建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“
import tensorflow as tf if __name__ == "__main__": v = tf.Variable(0.,name="v") #设置滑动平均模型的系数 ema = tf.train.ExponentialMovingAverage(0.99) #设置变量v使用滑动平均模型,tf.all_variables()设置所有变量 op = ema.apply([v]) #获取变量v的名字 print(v.name) #v:0 #创建一个保存模型的对象 save = tf.train.Saver() sess = tf.Session() #初始化所有变量 init = tf.initialize_all_variables() sess.run(init) #给变量v重新赋值 sess.run(tf.assign(v,10)) #应用平均滑动设置 sess.run(op) #保存模型文件 save.save(sess,"./model.ckpt") #输出变量v之前的值和使用滑动平均模型之后的值 print(sess.run([v,ema.average(v)])) #[10.0, 0.099999905]
载入影子变量并映射到变量
利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此 *** 作『TensorFlow』模型载入方法汇总
v = tf.Variable(1.,name="v") #定义模型对象 saver = tf.train.Saver({"v/ExponentialMovingAverage":v}) sess = tf.Session() saver.restore(sess,"./model.ckpt") print(sess.run(v)) #0.0999999
这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。
使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。
variables_to_restore函数的使用
v = tf.Variable(1.,name="v") #滑动模型的参数的大小并不会影响v的值 ema = tf.train.ExponentialMovingAverage(0.99) print(ema.variables_to_restore()) #{'v/ExponentialMovingAverage':} sess = tf.Session() saver = tf.train.Saver(ema.variables_to_restore()) saver.restore(sess,"./model.ckpt") print(sess.run(v)) #0.0999999
variables_to_restore会识别网络中的变量,并自动生成影子变量名。
通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。
5、官方文档例子
官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例
| Example usage when creating a training model: | | ```python | # Create variables. | var0 = tf.Variable(...) | var1 = tf.Variable(...) | # ... use the variables to build a training model... | ... | # Create an op that applies the optimizer. This is what we usually | # would use as a training op. | opt_op = opt.minimize(my_loss, [var0, var1]) | | # Create an ExponentialMovingAverage object | ema = tf.train.ExponentialMovingAverage(decay=0.9999) | | with tf.control_dependencies([opt_op]): | # Create the shadow variables, and add ops to maintain moving averages | # of var0 and var1. This also creates an op that will update the moving | # averages after each training step. This is what we will use in place | # of the usual training op. | training_op = ema.apply([var0, var1]) | | ...train the model by running training_op... | ```
6、batch_normal的例子
和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤
def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5): with tf.variable_scope(scope): # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True) # gamma = tf.get_variable(name='gamma', shape=[n_out], # initializer=tf.random_normal_initializer(1.0, stddev), trainable=True) batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments') ema = tf.train.ExponentialMovingAverage(decay=decay) def mean_var_with_update(): ema_apply_op = ema.apply([batch_mean,batch_var]) with tf.control_dependencies([ema_apply_op]): return tf.identity(batch_mean),tf.identity(batch_var) # identity之后会把Variable转换为Tensor并入图中, # 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制 mean,var = tf.cond(phase_train, mean_var_with_update, lambda: (ema.average(batch_mean),ema.average(batch_var))) normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps) return normed
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持考高分网。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)