The simplest solution is to create separate sessions, using a separate graph for each model:
```python
# Build a graph containing `net1`.
with tf.Graph().as_default() as net1_graph:
    net1 = CreateAlexNet()
    saver1 = tf.train.Saver(...)
sess1 = tf.Session(graph=net1_graph)
saver1.restore(sess1, 'epoch_10.ckpt')

# Build a separate graph containing `net2`.
with tf.Graph().as_default() as net2_graph:
    net2 = CreateAlexNet()
    saver2 = tf.train.Saver(...)
sess2 = tf.Session(graph=net2_graph)  # note: `net2_graph`, not `net1_graph`
saver2.restore(sess2, 'epoch_50.ckpt')
```
If for some reason that approach does not work, and you must use a single tf.Session (for example, because you want to combine the results from the two networks in another TensorFlow computation), the best solution is to:
- create the different networks in name scopes, as you are already doing, and
- create separate tf.train.Saver instances for the two networks, with an additional argument that remaps the variable names.

When constructing each saver, you can pass a dictionary as the var_list argument, mapping the names of the variables as they appear in the checkpoint (i.e., without the name-scope prefix) to the tf.Variable objects you created in each model. You can build var_list programmatically, and you should be able to do something like the following:
```python
with tf.name_scope("net1"):
    net1 = CreateAlexNet()
with tf.name_scope("net2"):
    net2 = CreateAlexNet()

# Strip the "net1/" prefix to recover the variable names stored in the checkpoint.
# (Slicing is used instead of str.lstrip(): lstrip() removes a *set* of leading
# characters, not a literal prefix, so it can mangle the variable name. `op.name`
# is used because checkpoint keys do not include the ":0" output suffix.)
net1_varlist = {v.op.name[len("net1/"):]: v
                for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
net1_saver = tf.train.Saver(var_list=net1_varlist)

# Strip the "net2/" prefix in the same way.
net2_varlist = {v.op.name[len("net2/"):]: v
                for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
net2_saver = tf.train.Saver(var_list=net2_varlist)

# ...
net1_saver.restore(sess, "epoch_10.ckpt")
net2_saver.restore(sess, "epoch_50.ckpt")
```
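An aside on building the var_list keys: a common pitfall is to strip the scope prefix with str.lstrip(), which treats its argument as a character set rather than a literal prefix and can silently corrupt variable names. A minimal pure-Python sketch (the variable name here is a made-up example, not taken from AlexNet):

```python
name = "net1/truncated_normal/weights:0"

# str.lstrip() strips any leading run of the characters {n, e, t, 1, /},
# so it also eats the leading "t" of "truncated_normal" -- wrong result.
print(name.lstrip("net1/"))        # runcated_normal/weights:0

# Slicing off the known prefix length keeps the rest of the name intact.
print(name[len("net1/"):])         # truncated_normal/weights:0

# Checkpoints store op names without the ":0" output suffix, so a Saver's
# var_list keys should also drop it.
print(name[len("net1/"):].rsplit(":", 1)[0])  # truncated_normal/weights
```

This is why the snippet above slices with `v.op.name[len("net1/"):]` instead of calling lstrip().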