tensorflow - tf.keras.layers.BatchNormalization: ValueError when reading a frozen graph
Problem description
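The model I use contains tf.keras.applications.MobileNetV2, and there is a problem with its frozen graph.
I found that the problem is related to the BatchNormalization layer, so I wrote a simpler test program, TESTE_PART1.py, with only 1 layer. I use the freeze_graph tool to create the .pb file. Then another Python program, TESTE_PART2.py, reads the .pb with tf.import_graph_def.
The error is: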
Traceback (most recent call last):
File "/home/daniel/sistema/anaconda3/envs/tensorflow_1.13/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 426, in import_graph_def
graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node batch_normalization_v1/cond/ReadVariableOp/Switch was passed float from batch_normalization_v1/moving_mean:0 incompatible with expected resource.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/daniel/sistema/anaconda3/envs/tensorflow_1.13/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/daniel/sistema/anaconda3/envs/tensorflow_1.13/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/daniel/comp/armis/Plates/pysrc/cornerdetect/teste_part2.py", line 9, in <module>
tf.import_graph_def(od_graph_def, name='')
File "/home/daniel/sistema/anaconda3/envs/tensorflow_1.13/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/daniel/sistema/anaconda3/envs/tensorflow_1.13/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 430, in import_graph_def
raise ValueError(str(e))
ValueError: Input 0 of node batch_normalization_v1/cond/ReadVariableOp/Switch was passed float from batch_normalization_v1/moving_mean:0 incompatible with expected resource.
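I also tried creating the .pb in Python with the freeze_session method described here: How to export Keras .h5 to tensorflow .pb? The result was the same.
The error complains about incompatible types. I modified TESTE_PART1.py to print the dtypes of the 2 operations, and both are "resource":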
O1 INPUT Tensor("batch_normalization_v1/moving_mean:0", shape=(), dtype=resource)
O1 INPUT Tensor("batch_normalization_v1/cond/pred_id:0", shape=(), dtype=bool)
O2 OUTPUT Tensor("batch_normalization_v1/moving_mean:0", shape=(), dtype=resource)
I tested this with tf 1.12 and 1.13. The OS is Linux 4.4.0-148-generic #174-Ubuntu SMP Tue May 7 12:20:14 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux
TESTE_PART1.py
import tensorflow as tf
l1 = tf.keras.layers.BatchNormalization(input_shape=(100,))
model = tf.keras.models.Sequential([l1])
print("output name", model.output.op.name)
print("input name", model.input.op.name)
# output name batch_normalization_v1/batchnorm/add_1
# input name batch_normalization_v1_input
saver = tf.train.Saver()
sess = tf.keras.backend.get_session()
saver.save(sess, "teste.saver_export")
op1 = sess.graph.get_operation_by_name("batch_normalization_v1/cond/ReadVariableOp/Switch")
for tensor in op1.inputs:
    print("O1 INPUT", tensor)
op2 = sess.graph.get_operation_by_name("batch_normalization_v1/moving_mean")
for output in op2.outputs:
    print("O2 OUTPUT", output)
freeze_graph --input_meta_graph=teste.saver_export.meta --input_checkpoint=teste.saver_export --output_graph=teste.freeze_graph.pb --output_node_names="batch_normalization_v1/batchnorm/add_1" --input_binary=true
$ ls -l teste*
-rw-rw-r-- 1 daniel daniel 6990 jun 1 18:35 teste.freeze_graph.pb
-rw-rw-r-- 1 daniel daniel 1600 jun 1 18:34 teste.saver_export.data-00000-of-00001
-rw-rw-r-- 1 daniel daniel 239 jun 1 18:34 teste.saver_export.index
-rw-rw-r-- 1 daniel daniel 28006 jun 1 18:34 teste.saver_export.meta
TESTE_PART2.py
import tensorflow as tf
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with open("teste.freeze_graph.pb", 'rb') as fd:
        serialized_graph = fd.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')  # the ValueError is raised here
Solution