tensorflow - 我们如何将 keras 模型 .h5 文件转换为 tensorflow 保存的模型 (.pb)
问题描述
我有一个经过训练的 keras 模型并将其保存为 h5 格式。我想在 google cloud ml 引擎上托管这个模型进行预测。如何将 keras 模型 .h5 文件转换为保存的模型。
解决方案
我看到这个代码的帖子(虽然我在许多其他帖子中看到过这个解决方案)是这样的:https ://www.dlology.com/blog/how-to-convert-trained-keras-model-to-tensorflow -并进行预测/
import tensorflow as tf
from keras import backend as K
# This line must be executed before loading Keras model.
K.set_learning_phase(0)
from keras.models import load_model
model = load_model('./model/keras_model.h5')
def freeze_session(session, keep_var_names=None, output_names=None,clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.
Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
# Graph -> GraphDef ProtoBuf
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
frozen_graph = freeze_session(K.get_session(),
output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)
推荐阅读
- opencl - OpenCL:为输入和输出创建缓冲区时如何指定数组的大小
- python - How to split a list of tuples based on the minimum value in each tuple?
- css - 如何正确迁移到具有相同按钮颜色和字体的 Bootstrap5?
- reactjs - Strapi 富文本内容未在网页上正确显示
- events - Quasar 框架 v-on:input 什么都不做
- google-visualization - 单击时控件(类别过滤器)未正确更新表格图表
- javascript - 反应类继承
- angular - “ng”命令只能通过管理员 cmd 执行
- constraints - 如何添加约束以使只有一个非整数值?
- java - GlassFish 5.1.0 错误,因为 dcs 为空