python - 将 keras (.h5) 模型转换为 tf 冻结图 (.pd) 格式后的精度损失
问题描述
我正在使用 tensorflow 1.12.0 和 keras 2.1.4(我一开始使用的是 2.2.4,问题是一样的。)
我正在处理一些模型传输任务,其中一个步骤是转换 keras 模型(.h5)到 tensorflow 冻结图形格式 (.pb)。加载 .h5 并在 keras 中测试时模型准确度为 98%,但在加载冻结图 (.pb) 并在 tensorflow 中测试时仅保持 10%。
我通过以下代码将 h5 转移到冻结图。
import tensorflow.keras as keras
import tensorflow as tf
import os
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.
Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
session.run(tf.initialize_all_variables())
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ''
frozen_graph = tf.graph_util.convert_variables_to_constants(
session, input_graph_def, output_names, freeze_var_names)
return frozen_graph
#keras h5 file
input_path='./'
input_file = './CNN_Mnist.h5'
weight_file_path = os.path.join(input_path, input_file)
output_graph_name = 'frozen_graph.pb'
#Load Module
keras.backend.set_learning_phase(0)
h5_model = keras.models.load_model(weight_file_path)
frozen_graph = freeze_session(keras.backend.get_session(), output_names=[out.op.name for out in h5_model.outputs])
tf.train.write_graph(frozen_graph, input_path, output_graph_name, as_text=False)
print('finish!')
然后我试图在张量流中加载冻结图并再次评估它,但我得到了非常糟糕的结果。这是我的评估代码。
mport os
import argparse
import shutil
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
import tensorflow.contrib.decent_q
def graph_eval(input_graph_def, input_node, output_node):
# MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = (x_test/255.0).astype(np.float32)
x_test = np.reshape(x_test, [-1, 28, 28, 1])
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
tf.import_graph_def(input_graph_def,name = '')
# Get input placeholders & tensors
images_in = tf.get_default_graph().get_tensor_by_name(input_node+':0')
labels = tf.placeholder(tf.int32,shape = [None,10])
# get output tensors
logits = tf.get_default_graph().get_tensor_by_name(output_node+':0')
# top 5 and top 1 accuracy
in_top5 = tf.nn.in_top_k(predictions=logits, targets=tf.argmax(labels, 1), k=5)
in_top1 = tf.nn.in_top_k(predictions=logits, targets=tf.argmax(labels, 1), k=1)
top5_acc = tf.reduce_mean(tf.cast(in_top5, tf.float32))
top1_acc = tf.reduce_mean(tf.cast(in_top1, tf.float32))
# Create the Computational graph
with tf.Session() as sess:
sess.run(tf.initializers.global_variables())
feed_dict={images_in: x_test, labels: y_test}
t5_acc,t1_acc = sess.run([top5_acc,top1_acc], feed_dict)
print (' Top 1 accuracy with validation set: {:1.4f}'.format(t1_acc))
print (' Top 5 accuracy with validation set: {:1.4f}'.format(t5_acc))
print ('FINISHED!')
#define arguments
graph='./frozen_graph_opt.pb'#output_graph #'./frozen_graph.pb'
input_node='conv2d_1_input_5'
output_node='dense_2_3/Softmax'
####
input_graph_def = tf.Graph().as_graph_def()
input_graph_def.ParseFromString(tf.gfile.FastGFile(graph, "rb").read())
graph_eval(input_graph_def,input_node,output_node)
训练和评估中的数据预处理都使用相同的功能。有人可以帮我指出在将 .h5 传输到 .pd 文件时应该检查哪些部分或应该注意什么吗?
或者我的评估代码有什么问题吗?
这是我在 Netron 中阅读的模型。
keras .h5 模型
keras .h5 模型
Tensorflow .pb 模型 Tensorflow .pb 模型
解决方案
推荐阅读
- python - 如何使用 python unittest 在模拟对象中模拟列表?
- python - Odoo - 计算字段导入
- angular - “Angulartics2GoogleAnalytics”类型上不存在属性“startTracking”
- python - Python 3 urllib Discord 与 Slack Bot
- junit - Junit双引号比较
- google-apps-script - 谷歌应用程序脚本如何在另一个工作表中复制条件格式
- unity3d - 为什么当玩家死亡时粒子不显示,它自己工作正常,但当它击中受伤的物体时不显示?(我正在做一个 2D 游戏)
- spring-boot - 无法使用@JmsListener 连接到远程队列
- unity3d - 如何将文件从 Unity 场景发送给我的朋友
- python - Libvirt Python SSH 连接超时