How to save a TensorFlow model in protobuf format?

Problem description

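Please help me solve my problem. I want to save my neural network in protobuf (.pb) format for use with OpenCV DNN. As input I have three files: .meta, .data and .index. As output I need a .pb file and a .pbtxt file.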
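Those three input files together form a TensorFlow 1.x checkpoint: the .meta file stores the graph definition, while the .index and .data-* files store the trained variable values. A minimal sketch, assuming TensorFlow 2.x with the tf.compat.v1 graph-mode APIs, that saves a tiny stand-in graph (the names input, w and output are hypothetical) and lists the files the Saver writes:

```python
import os
import tempfile

import tensorflow as tf

# .meta checkpoints belong to TF1-style graph mode, so switch eager execution off.
tf.compat.v1.disable_eager_execution()

ckpt_dir = tempfile.mkdtemp()
graph = tf.Graph()
with graph.as_default():
    # Hypothetical toy graph standing in for the real network.
    x = tf.compat.v1.placeholder(tf.float32, [None, 4], name="input")
    w = tf.compat.v1.get_variable("w", shape=[4, 3])
    tf.nn.softmax(tf.matmul(x, w), name="output")

    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        # Writes model.ckpt.meta, model.ckpt.index and model.ckpt.data-* files.
        prefix = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))

files = sorted(os.listdir(ckpt_dir))
print(prefix)
print(files)
```

Saver.save also writes a small text file named checkpoint that records the most recent checkpoint prefix; tools such as freeze_graph expect that prefix, not one of the individual files.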

My code, for example:

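import numpy as np
import tensorflow as tf
import tflearn
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.estimator import regression

# SIZE, LR, TEST_PACK and MODEL_NAME are constants defined elsewhere in the script.
# allow_pickle=True is needed on NumPy >= 1.16.3 to load an array of (image, label) pairs.
train_data = np.load(TEST_PACK, allow_pickle=True)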

tf.reset_default_graph()
convnet = input_data(shape=[None, SIZE, SIZE, 3], name='input')


convnet = conv_2d(convnet, 32, 10, activation='relu')
convnet = max_pool_2d(convnet, 30)

convnet = conv_2d(convnet, 64, 10, activation='relu')
convnet = max_pool_2d(convnet, 10)

convnet = conv_2d(convnet, 128, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = conv_2d(convnet, 64, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = conv_2d(convnet, 32, 5, activation='relu')
convnet = max_pool_2d(convnet, 5)

convnet = fully_connected(convnet, 1024, activation='relu')
convnet = dropout(convnet, 0.8)

convnet = fully_connected(convnet, 3, activation='softmax')

convnet = regression(convnet, optimizer='adam', learning_rate=LR, loss='categorical_crossentropy', name='targets')


model = tflearn.DNN(convnet, tensorboard_dir='log')

print('model loaded!')

train = train_data[:-500]
test = train_data[-500:]

X = np.array([i[0] for i in train]).reshape(-1,SIZE,SIZE,3)
Y = [i[1] for i in train]

test_x = np.array([i[0] for i in test]).reshape(-1,SIZE,SIZE,3)
test_y = [i[1] for i in test]

model.fit({'input': X}, {'targets': Y}, n_epoch=5, validation_set=({'input': test_x}, {'targets': test_y}), 
          snapshot_step=500, show_metric=True, run_id=MODEL_NAME)

I'm new to neural networks, so apologies if I'm talking nonsense :)

Tags: python, opencv, tensorflow, neural-network

Solution


The following function freezes your model and creates a ".pb" file.

import os

import tensorflow as tf
from tensorflow.python.tools import freeze_graph


def freeze_model(sess, logs_path, latest_checkpoint, model, pb_file_name, freeze_pb_file_name):
    """
    :param sess: TensorFlow session instance that holds the complete graph
    :param logs_path: string
                      directory where the checkpoint files are stored
    :param latest_checkpoint: string
                              checkpoint path (the prefix shared by the .index/.data files)
    :param model: model instance whose `nodes` attribute lists the output node names
    :param pb_file_name: string
                         name of the unfrozen graph file written before freezing
    :param freeze_pb_file_name: string
                                name of the frozen .pb file that stores the graph and weights
    """
    print("logs_path =", logs_path)
    # Write the unfrozen graph definition (write_graph uses text format by default).
    tf.train.write_graph(sess.graph.as_graph_def(), logs_path, pb_file_name)
    input_graph_path = os.path.join(logs_path, pb_file_name)
    input_saver_def_path = ""
    input_binary = False  # matches the text format written above
    input_checkpoint_path = latest_checkpoint
    output_graph_path = os.path.join(logs_path, freeze_pb_file_name)
    clear_devices = False
    # freeze_graph expects a comma-separated string of output node names.
    output_node_names = ",".join(model.nodes)
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    initializer_nodes = ""
    # Load the checkpoint weights into the graph and convert variables to constants.
    freeze_graph.freeze_graph(input_graph_path,
                              input_saver_def_path,
                              input_binary,
                              input_checkpoint_path,
                              output_node_names,
                              restore_op_name,
                              filename_tensor_name,
                              output_graph_path,
                              clear_devices,
                              initializer_nodes)
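As an alternative to calling the freeze_graph tool, the same freezing can be done in-process, which also makes it easy to write both a .pb and a .pbtxt file. A minimal sketch, assuming TensorFlow 2.x with tf.compat.v1: freeze_checkpoint and make_toy_checkpoint are hypothetical helpers, and the toy graph stands in for the real network. It rebuilds the graph from the .meta file, restores the weights from the .index/.data files, replaces variables with constants via tf.compat.v1.graph_util.convert_variables_to_constants, and writes a binary .pb plus a text .pbtxt.

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_eager_execution()


def freeze_checkpoint(ckpt_prefix, output_node_names, out_dir, name="frozen"):
    """Hypothetical helper: restore a checkpoint and write frozen .pb and .pbtxt files."""
    graph = tf.Graph()
    with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        # Rebuild the graph from the .meta file, then load weights from .index/.data.
        saver = tf.compat.v1.train.import_meta_graph(ckpt_prefix + ".meta", clear_devices=True)
        saver.restore(sess, ckpt_prefix)
        # Keep only what the outputs need and turn every variable into a constant.
        frozen = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_node_names)
    tf.io.write_graph(frozen, out_dir, name + ".pb", as_text=False)   # binary
    tf.io.write_graph(frozen, out_dir, name + ".pbtxt", as_text=True)  # text
    return frozen


def make_toy_checkpoint(ckpt_dir):
    """Hypothetical stand-in for the trained network: saves a tiny checkpoint."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, [None, 4], name="input")
        w = tf.compat.v1.get_variable("w", shape=[4, 3])
        tf.nn.softmax(tf.matmul(x, w), name="output")
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            return saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))


work_dir = tempfile.mkdtemp()
prefix = make_toy_checkpoint(work_dir)
frozen_graph = freeze_checkpoint(prefix, ["output"], work_dir)
print(sorted(node.op for node in frozen_graph.node))
```

The binary .pb is the file to pass as the model argument of cv2.dnn.readNetFromTensorflow. The .pbtxt written here is a full text dump of the frozen graph, constants included; depending on the layers used, OpenCV may instead want a simplified text graph as its optional config argument, so check the OpenCV documentation for your model type.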

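In freeze_model, model.nodes is a list of the names (strings) of the graph's output nodes; they are joined with commas into output_node_names. Do not leave output_node_names as an empty string (output_node_names = ""): freeze_graph needs at least one output node name, and without one it prints an error and returns without writing the frozen graph. Pass the name of the network's final op instead, for example the softmax layer.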
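To find a value for output_node_names, one option is to list the operations stored in the .meta file and pick the final one. A minimal sketch, assuming TensorFlow 2.x with tf.compat.v1; list_node_names and the toy graph are hypothetical. For the tflearn network in the question the softmax op will have a scope-dependent name, so print the full list rather than guessing it.

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_eager_execution()


def list_node_names(meta_path):
    """Hypothetical helper: return (name, op type) for every op stored in a .meta file."""
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.train.import_meta_graph(meta_path)
    return [(op.name, op.type) for op in graph.get_operations()]


# Hypothetical toy checkpoint so the sketch runs on its own.
ckpt_dir = tempfile.mkdtemp()
with tf.Graph().as_default() as g:
    x = tf.compat.v1.placeholder(tf.float32, [None, 4], name="input")
    w = tf.compat.v1.get_variable("w", shape=[4, 3])
    tf.nn.softmax(tf.matmul(x, w), name="output")
    with tf.compat.v1.Session(graph=g) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        prefix = tf.compat.v1.train.Saver().save(sess, os.path.join(ckpt_dir, "model.ckpt"))

nodes = list_node_names(prefix + ".meta")
# Ignore saver/optimizer bookkeeping ops; here the Softmax op is the network output.
candidates = [name for name, op_type in nodes if op_type == "Softmax"]
print(candidates)
```

The resulting name (without the ":0" tensor suffix) is what goes into output_node_names.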

