python - 从预训练模型进行微调后,TensorFlow 模型中的输出节点名称丢失
问题描述
我按照https://tensorflow-object-detection-api-tutorial.readthedocs.io上的教程来微调预训练模型以检测图像中的新对象。预训练模型是ssd_inception_v2_coco。
经过几千步后,我成功地训练和评估了模型,损失从 26 降至 1。但是,我未能使用以下代码创建冻结模型:
#this code runs in model dir
import tensorflow as tf
#make .pb file from model at step 1000
saver = tf.train.import_meta_graph(
'./model.ckpt-1000.meta', clear_devices=True)
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
saver.restore(sess, "./model.ckpt-1000")
#node names
i=0
for n in tf.get_default_graph().as_graph_def().node:
print(n.name,i);
i+=1
#end for
print("total:",i);
output_node_names=[
"detection_boxes","detection_classes",
"detection_scores","num_detections"
];
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess,input_graph_def,output_node_names);
#save to .pb file
output_graph="./model.pb"
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString());
#end with
sess.close();
错误是:
微调后的模型似乎丢失了输出节点名称。原始预训练模型中有这些输出节点名称(将上面代码中的检查点文件更改为原始训练模型中的):detection_boxes、detection_classes、detection_scores 和 num_detections。输出节点名称与原始节点名称完全相同,这是它们的索引(来自上面的节点名称“for”循环):
我的问题是如何保留原始预训练模型的输出节点名称?节点名称在代码中定义,但这里没有代码,只有一些配置和文件“train.py”。
PS。在total_loss后面有个叫summary_op的东西,不知道是不是输出(?):
解决方案
为了拥有“ image_tensor ”(输入)和其他输出节点名称“ detection_boxes ”、“ detection_classes ”、“ detection_scores ”、“ num_detections ”,请使用名为“ export_inference_graph.py ”的tensorflow/models/research/object_detection中的实用程序脚本. 该脚本甚至优化了冻结图(冻结模型)以进行推理。根据我的测试模型,节点数量从 26,000 个减少到 5,000 个;这对推理速度很有帮助。
这是 export_inference_graph.py 的链接: https ://github.com/tensorflow/models/blob/0558408514dacf2fe2860cd72ac56cbdf62a24c0/research/object_detection/export_inference_graph.py
如何运行:
#bash command
python3 export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path PATH_TO_PIPELINE.config \
--trained_checkpoint_prefix PATH_TO/model.ckpt-NUMBER \
--output_directory PATH_TO_NEW_DIR
有问题的 .pb 创建代码仅适用于从头创建并手动定义节点名称的模型,用于从 TensorFlow Model Zoo https://github.com/tensorflow/models/下载的预训练模型微调的模型检查点blob/master/research/object_detection/g3doc/detection_model_zoo.md,它不会工作!
推荐阅读
- mongodb - mongodb中可变嵌套文档的变化
- c# - C# 中的 Spliterator 等价物是什么?
- http - ejabberd/XMPP 聊天服务器中的混合内容错误
- c++ - 单引号作为大量分隔符?
- django - 为 Django 模型字段分配友好名称
- javascript - 提交时 JavaScript 表单验证时页面重置
- c# - Entity Framework 6 - Uncached data with dependency injection
- reactjs - 为什么我无法在 create-react-app 中初始化 git?
- java - 从 GitHub 转贴:Spring Boot 应用程序无法启动并出现错误“java.lang.NoSuchMethodError
- java - 如何为 BBP 算法做 base 16 来获得 PI