saved_model.prune() in TF 2.0

It looks like the way you are pruning the model in version 1 is fine; according to your error message, the resulting pruned model cannot be saved because it is not "trackable", which is a necessary condition for saving a model with tf.saved_model.save. One way to produce a trackable object is to inherit from tf.Module, as described in the guides on using the SavedModel format and on concrete functions. Below is an example that first tries to save a tf.function object directly (which fails because it is not trackable), then inherits from tf.Module and saves the resulting object:

(Using Python 3.7.6, TensorFlow 2.1.0 and NumPy 1.18.1)

import tensorflow as tf
import numpy as np

# Define a random TensorFlow function and generate a reference output
conv_filter = tf.random.normal([1, 2, 4, 2], seed=1254)

@tf.function
def conv_model(x):
    return tf.nn.conv2d(x, conv_filter, 1, "SAME")

input_tensor = tf.ones([1, 2, 3, 4])
output_tensor = conv_model(input_tensor)
print("Original model outputs:", output_tensor, sep="\n")

# Try saving the model: it won't work because a tf.function is not trackable
export_dir = "./tmp/"
try:
    tf.saved_model.save(conv_model, export_dir)
except ValueError:
    print("Can't save {} object because it's not trackable".format(type(conv_model)))

# Now define a trackable object by inheriting from the tf.Module class
class MyModule(tf.Module):
    @tf.function
    def __call__(self, x):
        return conv_model(x)

# Instantiate the trackable object, and call once to trace-compile a graph
module_func = MyModule()
module_func(input_tensor)
tf.saved_model.save(module_func, export_dir)

# Restore the model and verify that the outputs are consistent
restored_model = tf.saved_model.load(export_dir)
restored_output_tensor = restored_model(input_tensor)
print("Restored model outputs:", restored_output_tensor, sep="\n")

if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()):
    print("Outputs are consistent :)")
else:
    print("Outputs are NOT consistent :(")

Console output:

Original model outputs:
tf.Tensor(
[[[[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]

  [[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Can't save <class 'tensorflow.python.eager.def_function.Function'> object
because it's not trackable
Restored model outputs:
tf.Tensor(
[[[[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]

  [[-2.3629642   1.2904963 ]
   [-2.3629642   1.2904963 ]
   [-0.02110204  1.3400152 ]]]], shape=(1, 2, 3, 2), dtype=float32)
Outputs are consistent :)

So, in your case, you should try modifying your code as follows:

svmod = tf.saved_model.load(fn)  # version 1
svmod2 = svmod.prune(feeds=['foo:0'], fetches=['bar:0'])

class Exportable(tf.Module):
    @tf.function
    def __call__(self, model_inputs):
        return svmod2(model_inputs)

svmod2_export = Exportable()
svmod2_export(typical_input)    # call once with typical input to trace-compile
tf.saved_model.save(svmod2_export, '/tmp/saved_model/')
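
To sanity-check the export, you can load the SavedModel back and call it the same way the first example does. Here is a minimal sketch (not from the original answer) that reuses the svmod2_export, typical_input and '/tmp/saved_model/' names from the snippet above:

# Sketch: reload the exported pruned model and call it again.
# The output structure mirrors the fetches passed to prune().
restored_pruned = tf.saved_model.load('/tmp/saved_model/')
restored_output = restored_pruned(typical_input)
reference_output = svmod2_export(typical_input)
print("Restored pruned model output:", restored_output, sep="\n")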

If you would rather not inherit from tf.Module, you can alternatively just instantiate a tf.Module object and add a tf.function method/callable attribute, like so:

to_export = tf.Module()
to_export.call = tf.function(conv_model)
to_export.call(input_tensor)
tf.saved_model.save(to_export, export_dir)

restored_module = tf.saved_model.load(export_dir)
restored_func = restored_module.call
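
Usage is the same as with the class-based export: the restored attribute is a callable that reproduces the traced function. A short sketch, reusing input_tensor and output_tensor from the first example:

# Call the restored tf.function attribute and compare with the reference output
restored_output_tensor = restored_func(input_tensor)
if np.array_equal(output_tensor.numpy(), restored_output_tensor.numpy()):
    print("Outputs are consistent :)")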

