Keras本身不包括将TensorFlow图导出为协议缓冲区文件的任何方法,但是您可以使用常规TensorFlow实用程序来实现。这是一篇博客文章,解释了如何使用
freeze_graph.pyTensorFlow中包含的实用程序脚本执行此 *** 作,这是完成 *** 作的“典型”方式。
但是,我个人觉得必须创建一个检查点,然后运行一个外部脚本来获取模型,但我更喜欢从我自己的Python代码中执行此 *** 作,因此我使用了这样的函数:
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): """ Freezes the state of a session into a pruned computation graph. Creates a new computation graph where variable nodes are replaced by constants taking their current value in the session. The new graph will be pruned so subgraphs that are not necessary to compute the requested outputs are removed. @param session The TensorFlow session to be frozen. @param keep_var_names A list of variable names that should not be frozen, or None to freeze all the variables in the graph. @param output_names Names of the relevant graph outputs. @param clear_devices Remove the device directives from the graph for better portability. @return The frozen graph definition. """ graph = session.graph with graph.as_default(): freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) output_names = output_names or [] output_names += [v.op.name for v in tf.global_variables()] input_graph_def = graph.as_graph_def() if clear_devices: for node in input_graph_def.node: node.device = "" frozen_graph = tf.graph_util.convert_variables_to_constants( session, input_graph_def, output_names, freeze_var_names) return frozen_graph
这是实施的启发
freeze_graph.py。参数也类似于脚本。
session是TensorFlow会话对象。
keep_var_names仅在您希望不冻结某些变量时才需要(例如,对于有状态模型),通常不需要。
output_names是包含产生所需输出的 *** 作名称的列表。
clear_devices只需删除任何设备指令即可使图形更具可移植性。因此,对于
model具有一个输出的典型Keras
,您将执行以下 *** 作:
from keras import backend as K# Create, compile and train model...frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
然后,您可以像往常一样使用
tf.train.write_graph以下命令将图形写入文件:
tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)