看来您在第1版中修剪模型的方式很好;根据您的错误消息,无法保存生成的修剪模型,因为它不是“可跟踪的”,这是使用保存模型的必要条件
tf.saved_model.save。生成可跟踪对象的一种方法是从
tf.Module类继承,如使用SavedModel格式和具体函数的指南中所述。下面是一个示例,尝试保存一个
tf.function对象(由于该对象不可跟踪而失败),从继承
tf.module并保存生成的对象:
(使用Python版本3.7.6,TensorFlow版本2.1.0和NumPy版本1.18.1)
import tensorflow as tf, numpy as np# Define a random TensorFlow function and generate a reference outputconv_filter = tf.random.normal([1, 2, 4, 2], seed=1254)@tf.functiondef conv_model(x): return tf.nn.conv2d(x, conv_filter, 1, "SAME")input_tensor = tf.ones([1, 2, 3, 4])output_tensor = conv_model(input_tensor)print("Original model outputs:", output_tensor, sep="n")# Try saving the model: it won't work because a tf.function is not trackableexport_dir = "./tmp/"try: tf.saved_model.save(conv_model, export_dir)except ValueError: print( "Can't save {} object because it's not trackable".format(type(conv_model)))# Now define a trackable object by inheriting from the tf.Module classclass MyModule(tf.Module): @tf.function def __call__(self, x): return conv_model(x)# Instantiate the trackable object, and call once to trace-compile a graphmodule_func = MyModule()module_func(input_tensor)tf.saved_model.save(module_func, export_dir)# Restore the model and verify that the outputs are consistentrestored_model = tf.saved_model.load(export_dir)restored_output_tensor = restored_model(input_tensor)print("Restored model outputs:", restored_output_tensor, sep="n")if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()): print("Outputs are consistent :)")else: print("Outputs are NOT consistent :(")
控制台输出:
Original model outputs:tf.Tensor([[[[-2.3629642 1.2904963 ] [-2.3629642 1.2904963 ] [-0.02110204 1.3400152 ]] [[-2.3629642 1.2904963 ] [-2.3629642 1.2904963 ] [-0.02110204 1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)Can't save <class 'tensorflow.python.eager.def_function.Function'> objectbecause it's not trackableRestored model outputs:tf.Tensor([[[[-2.3629642 1.2904963 ] [-2.3629642 1.2904963 ] [-0.02110204 1.3400152 ]] [[-2.3629642 1.2904963 ] [-2.3629642 1.2904963 ] [-0.02110204 1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)Outputs are consistent :)
因此,您应该尝试按以下方式修改代码:
svmod = tf.saved_model.load(fn) #version 1svmod2 = svmod.prune(feeds=['foo:0'], fetches=['bar:0'])class Exportable(tf.Module): @tf.function def __call__(self, model_inputs): return svmod2(model_inputs)svmod2_export = Exportable()svmod2_export(typical_input) # call once with typical input to trace-compiletf.saved_model.save(svmod2_export, '/tmp/saved_model/')
如果您不想继承自
tf.Module,则可以替换实例代码,实例化一个
tf.Module对象并添加
tf.function方法/可调用属性,如下所示:
to_export = tf.Module()to_export.call = tf.function(conv_model)to_export.call(input_tensor)tf.saved_model.save(to_export, export_dir)restored_module = tf.saved_model.load(export_dir)restored_func = restored_module.call
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)