Accuracy loss after converting a Keras (.h5) model to a TensorFlow frozen graph (.pb)

Problem description

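I'm using TensorFlow 1.12.0 and Keras 2.1.4 (I started with Keras 2.2.4, and the problem was the same).
I'm working on a model-transfer task, and one of the steps is converting a Keras model (.h5) to the TensorFlow frozen-graph format (.pb). When I load the .h5 file and test it in Keras, the model reaches 98% accuracy. When I load the frozen graph (.pb) and test it in TensorFlow, accuracy drops to only 10%.
I convert the .h5 file to a frozen graph with the following code: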

import tensorflow.keras as keras
import tensorflow as tf
import os
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.
    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    session.run(tf.initialize_all_variables())
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ''
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
 
 
#keras h5 file
input_path='./'
input_file = './CNN_Mnist.h5'
weight_file_path = os.path.join(input_path, input_file)
output_graph_name = 'frozen_graph.pb'

#Load Module
keras.backend.set_learning_phase(0)
h5_model = keras.models.load_model(weight_file_path)

frozen_graph = freeze_session(keras.backend.get_session(), output_names=[out.op.name for out in h5_model.outputs])
tf.train.write_graph(frozen_graph, input_path, output_graph_name, as_text=False)
print('finish!')
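One thing worth checking in freeze_session above is its first line, session.run(tf.initialize_all_variables()). That op runs the initializer of every variable in the graph. Because it runs after load_model, it can overwrite the trained weights with freshly initialized values, and an untrained 10-class MNIST classifier would score at roughly chance level (about 10%), which matches the symptom. A quick way to test this hypothesis is to snapshot the weights before and after freezing and compare the two snapshots. The sketch below is hypothetical (weights_unchanged is a made-up helper). It assumes the snapshots come from Keras's Model.get_weights(), which returns a list of NumPy arrays read from the current session.

```python
import numpy as np


def weights_unchanged(before, after, atol=0.0):
    """Return True if two weight snapshots are identical (within atol).

    Hypothetical helper: `before` and `after` are assumed to be lists of
    NumPy arrays, e.g. h5_model.get_weights() captured before and after
    calling freeze_session().
    """
    if len(before) != len(after):
        return False
    for w_before, w_after in zip(before, after):
        # A changed shape, or any element differing by more than atol,
        # means the variables were modified (for example, re-initialized).
        if w_before.shape != w_after.shape:
            return False
        if not np.allclose(w_before, w_after, rtol=0.0, atol=atol):
            return False
    return True
```

Usage: capture before = h5_model.get_weights() right after load_model, call freeze_session as above, then capture after = h5_model.get_weights() and print weights_unchanged(before, after). If it prints False, the re-initialization line is the most likely cause of the accuracy drop.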

I then loaded the frozen graph in TensorFlow and evaluated it again, but the results were very poor. Here is my evaluation code:

import os
import argparse
import shutil
import tensorflow as tf
import numpy as np

from tensorflow.python.platform import gfile
import tensorflow.contrib.decent_q
def graph_eval(input_graph_def, input_node, output_node):

    # MNIST dataset    
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_test = (x_test/255.0).astype(np.float32)
    x_test = np.reshape(x_test, [-1, 28, 28, 1])
    y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

    tf.import_graph_def(input_graph_def,name = '')

    # Get input placeholders & tensors
    images_in = tf.get_default_graph().get_tensor_by_name(input_node+':0')
    labels = tf.placeholder(tf.int32,shape = [None,10])

    # get output tensors
    logits = tf.get_default_graph().get_tensor_by_name(output_node+':0')

    # top 5 and top 1 accuracy
    in_top5 = tf.nn.in_top_k(predictions=logits, targets=tf.argmax(labels, 1), k=5)
    in_top1 = tf.nn.in_top_k(predictions=logits, targets=tf.argmax(labels, 1), k=1)
    top5_acc = tf.reduce_mean(tf.cast(in_top5, tf.float32))
    top1_acc = tf.reduce_mean(tf.cast(in_top1, tf.float32))
    
    # Create the Computational graph
    with tf.Session() as sess:
        
        sess.run(tf.initializers.global_variables())
      
        feed_dict={images_in: x_test, labels: y_test}
        t5_acc,t1_acc = sess.run([top5_acc,top1_acc], feed_dict)
    
        print (' Top 1 accuracy with validation set: {:1.4f}'.format(t1_acc))
        print (' Top 5 accuracy with validation set: {:1.4f}'.format(t5_acc))

    print ('FINISHED!')

#define arguments
graph='./frozen_graph_opt.pb'#output_graph #'./frozen_graph.pb'
input_node='conv2d_1_input_5'
output_node='dense_2_3/Softmax'
####
input_graph_def = tf.Graph().as_graph_def()
input_graph_def.ParseFromString(tf.gfile.FastGFile(graph, "rb").read())
graph_eval(input_graph_def,input_node,output_node)
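To rule out the evaluation graph itself (the in_top_k ops), you can fetch the output tensor directly and score it in NumPy. For example, inside the session in graph_eval, run probs = sess.run(logits, {images_in: x_test}). Note that the tensor the script calls logits is actually the Softmax output. Softmax preserves the ranking of scores, so top-k results are the same either way. The sketch below uses a hypothetical helper name, top_k_accuracy, and may differ from tf.nn.in_top_k on exact ties.

```python
import numpy as np


def top_k_accuracy(probs, labels, k=1):
    """Fraction of rows whose true class is among the k highest scores.

    probs:  (N, num_classes) array of model outputs (softmax or logits).
    labels: (N,) integer class ids, or (N, num_classes) one-hot rows.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    if labels.ndim == 2:
        # Convert one-hot rows (as produced by to_categorical) to class ids.
        labels = labels.argmax(axis=1)
    # Column indices of the k largest scores in each row.
    top_k = np.argsort(-probs, axis=1)[:, :k]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())
```

If top_k_accuracy(probs, y_test, k=1) agrees with the TensorFlow top-1 figure, the evaluation code is probably fine and the problem lies in the frozen graph.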

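Training and evaluation both use the same preprocessing function. Can someone point out which parts I should check, or what I should watch out for, when converting a .h5 file to a .pb file?
Or is there something wrong with my evaluation code?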
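Comparing accuracies does not show where the two pipelines diverge. A more direct check is to feed the same preprocessed batch x to both models and compare the raw outputs, for example keras_probs = h5_model.predict(x) in Keras and frozen_probs = sess.run(logits, {images_in: x}) on the frozen graph. The sketch below is hedged: compare_outputs is a hypothetical helper, and the default tolerance of 1e-5 is an assumption, not a documented threshold.

```python
import numpy as np


def compare_outputs(keras_probs, frozen_probs, atol=1e-5):
    """Compare two models' outputs on the same input batch.

    Returns (max_abs_diff, argmax_agreement, allclose), where
    argmax_agreement is the fraction of rows whose predicted class matches.
    """
    a = np.asarray(keras_probs, dtype=np.float64)
    b = np.asarray(frozen_probs, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError("shape mismatch: %s vs %s" % (a.shape, b.shape))
    max_abs_diff = float(np.max(np.abs(a - b)))
    agreement = float(np.mean(a.argmax(axis=1) == b.argmax(axis=1)))
    return max_abs_diff, agreement, bool(np.allclose(a, b, rtol=0.0, atol=atol))
```

If the outputs differ substantially even though the inputs are identical, preprocessing can be ruled out, and the converted graph (for example, its weight values) becomes the prime suspect.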

Here are the models as viewed in Netron:
[Image: Keras .h5 model]

[Image: TensorFlow .pb model]
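The node names used in the evaluation script, conv2d_1_input_5 and dense_2_3/Softmax, carry numeric suffixes. TensorFlow appends such suffixes when the same op or name-scope name is created more than once in one graph. This suggests the model was built several times in the same default graph before freezing (for example, by re-running load_model in an interactive session). Hard-coded names like these can change from run to run, so it can help to read them from the GraphDef instead. The sketch below is hedged: find_io_nodes is a hypothetical helper that looks only at op types.

```python
def find_io_nodes(nodes):
    """Pick candidate input/output node names from (name, op_type) pairs.

    nodes: iterable of (name, op_type) tuples, e.g. built from a GraphDef as
    [(n.name, n.op) for n in graph_def.node].
    Returns (placeholders, softmax_outputs) as lists of names.
    """
    nodes = list(nodes)  # allow generators; we iterate twice below
    placeholders = [name for name, op in nodes if op == "Placeholder"]
    softmax_outputs = [name for name, op in nodes if op == "Softmax"]
    return placeholders, softmax_outputs
```

Usage: find_io_nodes([(n.name, n.op) for n in input_graph_def.node]) lists the Placeholder and Softmax nodes, which you can then pass to graph_eval in place of the hard-coded names.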

Tags: python, tensorflow, machine-learning, keras, deep-learning
