python - Tensorflow 2.1/Keras - 尝试冻结图时出现“输出节点不在图中”错误
问题描述
我正在尝试保存使用 Keras 创建并保存为 .h5 文件的模型,但每次尝试运行 freeze_session 函数时都会收到此错误消息:output_node/Identity is not in graph
这是我的代码(我使用的是 Tensorflow 2.1.0):
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a pruned computation graph.
Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
pruned so subgraphs that are not necessary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.compat.v1.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
session, input_graph_def, output_names, freeze_var_names)
return frozen_graph
model=kr.models.load_model("model.h5")
model.summary()
# inputs:
print('inputs: ', model.input.op.name)
# outputs:
print('outputs: ', model.output.op.name)
#layers:
layer_names=[layer.name for layer in model.layers]
print(layer_names)
哪个打印:
inputs: input_node
outputs: output_node/Identity
['input_node', 'conv2d_6', 'max_pooling2d_6', 'conv2d_7', 'max_pooling2d_7', 'conv2d_8', 'max_pooling2d_8', 'flatten_2', 'dense_4', 'dense_5', 'output_node']
正如预期的那样(与我在训练后保存的模型中相同的层名称和输出)。
然后我尝试调用 freeze_session 函数并保存生成的冻结图:
frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
write_graph(frozen_graph, './', 'graph.pb', as_text=False)
但我收到此错误:
AssertionError Traceback (most recent call last)
<ipython-input-4-1848000e99b7> in <module>
----> 1 frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
2 write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
3 write_graph(frozen_graph, './', 'graph.pb', as_text=False)
<ipython-input-2-3214992381a9> in freeze_session(session, keep_var_names, output_names, clear_devices)
24 node.device = ""
25 frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
---> 26 session, input_graph_def, output_names, freeze_var_names)
27 return frozen_graph
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
322 'in a future version' if date is None else ('after %s' % date),
323 instructions)
--> 324 return func(*args, **kwargs)
325 return tf_decorator.make_decorator(
326 func, new_func, 'deprecated',
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist, variable_names_blacklist)
275 # This graph only includes the nodes needed to evaluate the output nodes, and
276 # removes unneeded nodes like those involved in saving and assignment.
--> 277 inference_graph = extract_sub_graph(input_graph_def, output_node_names)
278
279 # Identify the ops in the graph.
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
322 'in a future version' if date is None else ('after %s' % date),
323 instructions)
--> 324 return func(*args, **kwargs)
325 return tf_decorator.make_decorator(
326 func, new_func, 'deprecated',
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
195 name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
196 graph_def)
--> 197 _assert_nodes_are_present(name_to_node, dest_nodes)
198
199 nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in _assert_nodes_are_present(name_to_node, nodes)
150 """Assert that nodes are present in the graph."""
151 for d in nodes:
--> 152 assert d in name_to_node, "%s is not in graph" % d
153
154
**AssertionError: output_node/Identity is not in graph**
我已经尝试过,但我真的不知道如何解决这个问题,所以任何帮助将不胜感激。
解决方案
如果您使用 TensorFlow 2.x 版,请添加:
tf.compat.v1.disable_eager_execution()
这应该有效。我没有检查生成的 pb 文件,但它应该可以工作。
反馈表示赞赏。
编辑:但是,例如,在这个线程之后,TF1 和 TF2 pb 文件是根本不同的。我的解决方案可能无法正常工作或实际上创建了一个 TF1 pb 文件。
如果你然后遇到
RuntimeError:尝试使用已关闭的会话。
这可以通过重新启动内核来解决。使用上面的线,您只有一枪。
推荐阅读
- javascript - 覆盖 Firestore 中的集合
- pyspark - 如何使动态查询过滤器在 pyspark 中运行?
- sql-server - 如何在 XML SQL SERVER 中转义单引号
- c - 动态分配指针数组(K&R 练习 5-13)
- python - 需要将有关人员的信息存储在文件中并提取特定信息(python)
- python - 如何使用 python 3 在 for 循环中创建新数组?
- json - 将 Shopify API 与 Google Sheet FetchUrl 应用程序一起使用
- command-line - ng:尽管安装了@angular/cli,但找不到命令
- angular - 如何更改谷歌地图中的地图类型ID以及如何删除工具提示
- amazon-web-services - 使用gzip压缩卸载时如何卸载csv文件类型?