python - 如何以 protobuf 格式保存 Tensorflow 模型?
问题描述
请帮我解决我的问题。我想将我的神经网络保存为 OpenCV DNN 的 protobuf (pb) 格式。在输入中,我有 3 个文件:.meta、.data、.index。作为输出,我需要 .pb 和 .pbtxt 文件。
代码,例如:
train_data = np.load(TEST_PACK)
tf.reset_default_graph()
convnet = input_data(shape=[None, SIZE, SIZE, 3], name='input')
convnet = conv_2d(convnet, 32, 10, activation='relu')
convnet = max_pool_2d(convnet, 30)
convnet = conv_2d(convnet, 64, 10, activation='relu')
convnet = max_pool_2d(convnet, 10)
convnet = conv_2d(convnet, 128, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)
convnet = fully_connected(convnet, 1024, activation='relu')
convnet = dropout(convnet, 0.8)
convnet = fully_connected(convnet, 3, activation='softmax')
convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')
model = tflearn.DNN(convnet, tensorboard_dir='log')
print('model loaded!')
train = train_data[:-500]
test = train_data[-500:]
X = np.array([i[0] for i in train]).reshape(-1,SIZE,SIZE,3)
Y = [i[1] for i in train]
test_x = np.array([i[0] for i in test]).reshape(-1,SIZE,SIZE,3)
test_y = [i[1] for i in test]
model.fit({'input': X}, {'targets': Y}, n_epoch=5, validation_set=({'input': test_x}, {'targets': test_y}),
snapshot_step=500, show_metric=True, run_id=MODEL_NAME)
我是神经网络的新手,如果我胡说八道,请道歉))
解决方案
以下函数将冻结您的模型并创建“.pb”文件。
def freeze_model(sess, logs_path, latest_checkpoint, model, pb_file_name, freeze_pb_file_name):
"""
:param sess : tensor-flow session instance which creates the all graph information
:param logs_path: string
directory path where the checkpoint files are stored
:param latest_checkpoint: string
checkpoint file path
:param model: model instance for extracting the nodes explicitly
:param pb_file_name: string
Name of trainable pb file where the graph and weights will be stored
:param freeze_pb_file_name: string
Name of freeze pb file where the graph and weights will be stored
"""
print("logs_path =", logs_path)
tf.train.write_graph(sess.graph.as_graph_def(), logs_path, pb_file_name)
input_graph_path = os.path.join(logs_path, pb_file_name)
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = latest_checkpoint
output_graph_path = os.path.join(logs_path, freeze_pb_file_name)
clear_devices = False
output_node_names = ",".join(model.nodes)
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
initializer_nodes = ""
freeze_graph.freeze_graph(input_graph_path,
input_saver_def_path,
input_binary,
input_checkpoint_path,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph_path,
clear_devices,
initializer_nodes)
model.nodes
是张量节点的列表。我认为您可以创建output_node_names
一个空字符串:output_node_names = ""
推荐阅读
- javascript - 使用 sanity cms 的特定帖子
- node.js - npx create-react-app -> 引擎“节点”与此模块不兼容。预期版本 ">= 10.x"。得到“8.16.0”——Arch Linux
- python - Bluez/Python 缓冲导致蓝牙连接延迟
- cuda - 难以使用 atomicMin 在矩阵中找到最小值
- javascript - 用空合并缩短三元
- python - 如何用 ppid = 1 杀死僵尸进程?
- python-3.x - 在 Union 中处理具有值的参数的 Pythonic 方式
- goland - Goland 搜索功能未索引项目中的所有文件
- javascript - 如何使用javascript在刷新时保存切换按钮状态
- visual-studio-code - 在 VScode 选项卡中禁用问题徽章