c++ - 如何在 C++ 中使用 Keras SavedModel
问题描述
我有一个训练有素的经典 CNN(预训练移动网络)用于图像分类。我现在想从 c++ 中使用这个模型。据我了解,我需要创建一个模型库,它可以接受输入并返回其输出。我将模型保存为格式.pb
(SavedModel)。
我已经尝试过CppFlow,错误表明它无法读取我的模型。我认为这是由于与 TF 2.0 不兼容造成的。
我也有SavedModel
工作的命令行界面,但我不知道如何在我的 cpp 应用程序中使用它。
我想知道如何构建我的模型库并使用该库以便它可以即时进行预测。任何指导都会有所帮助。如果需要任何其他信息,请告诉我。
解决方案
在 C++ 中使用 keras 模型的一种方法是将其转换为 TensorFlow.pb
格式。我刚刚编写了一个用于执行此操作的脚本,如下所示。
用法:python script.py keras_model.hdf5
它输出 tensorflow 模型作为输入文件名附加.pb
.
然后您可以使用TF C++ api来读取模型并进行推理。在 C++ TF 中使用图像识别模型标记图像的详细示例位于此处。
另一种选择 - 您可以通过从 C++ 调用 Python API 直接使用 Keras,这并不难,有独立的 Python,它是静态编译的,这意味着根本没有 dll/共享库依赖项,因此 Python 解释器可以完全编译成 C++ 单个二进制文件. Internet 上也有许多库可以帮助您轻松地从 C++ 运行 Python。
import sys, os
from keras import backend as K
from keras.models import load_model
import tensorflow as tf
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.
Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
# Graph -> GraphDef ProtoBuf
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
if len(sys.argv) <= 1:
print('Usage: python script.py keras_model.hdf5')
sys.exit(0)
else:
ifname = sys.argv[1]
model = load_model(ifname)
frozen_graph = freeze_session(
K.get_session(),
output_names = [out.op.name for out in model.outputs],
)
tf.io.write_graph(frozen_graph, os.path.dirname(ifname), ifname + '.pb', as_text = False)
推荐阅读
- rust - 有没有办法指定多个版本的依赖项对我的库有效?
- java - http 状态 500 - 内部服务器错误(拒绝访问)
- javascript - 将用户重定向到不同的页面取决于 React 中的 userType
- python - 单独调用规则的变量并为特定规则添加独立环境
- angular - 如何使用 Leaflet 和 OpenStreetMap 将我的地图视图重置回我在 Ionic 中的位置?
- logging - 带有日志记录的简单 Alexa 技能
- c# - 从命令行部署到 Azure 会抛出 Http Status 500,而从 VS'17 部署到 Azure 会抛出 Http Status 500
- java - 如何获得与地图条目中列出的日期相关的值的累积
- json - 向 IAM 角色添加内联策略
- multithreading - 使用 Pytest 在多台机器上运行多进程、多线程测试