如何将Keras .h5导出到tensorflow .pb?

如何将Keras .h5导出到tensorflow .pb?,第1张

如何将Keras .h5导出到tensorflow .pb?

Keras本身不包括将TensorFlow图导出为协议缓冲区文件的任何方法,但是您可以使用常规TensorFlow实用程序来实现。这是一篇博客文章,解释了如何使用

freeze_graph.py
TensorFlow中包含的实用程序脚本执行 *** 作,这是完成 *** 作的“典型”方式。

但是,我个人觉得必须创建一个检查点,然后运行一个外部脚本来获取模型,但我更喜欢从我自己的Python代码中执行此 *** 作,因此我使用了这样的函数:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):    """    Freezes the state of a session into a pruned computation graph.    Creates a new computation graph where variable nodes are replaced by    constants taking their current value in the session. The new graph will be    pruned so subgraphs that are not necessary to compute the requested    outputs are removed.    @param session The TensorFlow session to be frozen.    @param keep_var_names A list of variable names that should not be frozen,    or None to freeze all the variables in the graph.    @param output_names Names of the relevant graph outputs.    @param clear_devices Remove the device directives from the graph for better portability.    @return The frozen graph definition.    """    graph = session.graph    with graph.as_default():        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))        output_names = output_names or []        output_names += [v.op.name for v in tf.global_variables()]        input_graph_def = graph.as_graph_def()        if clear_devices: for node in input_graph_def.node:     node.device = ""        frozen_graph = tf.graph_util.convert_variables_to_constants( session, input_graph_def, output_names, freeze_var_names)        return frozen_graph

这是实施的启发

freeze_graph.py
。参数也类似于脚本。
session
是TensorFlow会话对象。
keep_var_names
仅在您希望不冻结某些变量时才需要(例如,对于有状态模型),通常不需要。
output_names
是包含产生所需输出的 *** 作名称的列表。
clear_devices
只需删除任何设备指令即可使图形更具可移植性。因此,对于
model
具有一个输出的典型Keras
,您将执行以下 *** 作:

from keras import backend as K# Create, compile and train model...frozen_graph = freeze_session(K.get_session(),        output_names=[out.op.name for out in model.outputs])

然后,您可以像往常一样使用

tf.train.write_graph
以下命令将图形写入文件:

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)


欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/zaji/5650012.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存