python - Tensorflow:从图中删除节点
问题描述
我正在尝试从图中删除一些节点并将其保存在 .pb
只能将需要的节点添加到新mod_graph_def
图形中,但是图形在其他节点输入中仍然有一些对已删除节点的引用,但我无法修改节点的输入的问题:
def delete_ops_from_graph():
with open(input_model_filepath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
nodes = []
for node in graph_def.node:
if 'Neg' in node.name:
print('Drop', node.name)
else:
nodes.append(node)
mod_graph_def = tf.GraphDef()
mod_graph_def.node.extend(nodes)
# The problem that graph still have some references to deleted node in other nodes inputs
for node in mod_graph_def.node:
inp_names = []
for inp in node.input:
if 'Neg' in inp:
pass
else:
inp_names.append(inp)
node.input = inp_names # TypeError: Can't set composite field
with open(output_model_filepath, 'wb') as f:
f.write(mod_graph_def.SerializeToString())
解决方案
def delete_ops_from_graph():
with open(input_model_filepath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Delete nodes
nodes = []
for node in graph_def.node:
if 'Neg' in node.name:
print('Drop', node.name)
else:
nodes.append(node)
mod_graph_def = tf.GraphDef()
mod_graph_def.node.extend(nodes)
# Delete references to deleted nodes
for node in mod_graph_def.node:
inp_names = []
for inp in node.input:
if 'Neg' in inp:
pass
else:
inp_names.append(inp)
del node.input[:]
node.input.extend(inp_names)
with open(output_model_filepath, 'wb') as f:
f.write(mod_graph_def.SerializeToString())
推荐阅读
- c++ - c++ std::string 在不同的编译器和标准库之间是二进制兼容的吗?
- r - 将双精度转换为日期
- gradle - 如何懒惰地设置任务的属性?
- javascript - data() 不是用于查询 firebase 的函数
- c - 我试图找到四个 int 中最大的一个。与 C. 但程序不工作
- clojure - 如何让 rebel-readline 显示更多候选人
- lambda-calculus - 什么是 J 运算符,它与 call/cc 相同吗?
- gitlab - Gitlab 代码质量未找到要写入的文件
- c# - 如何使用 EFCore 3.1 从 DbContext 中选择特定的 DbSet?
- elasticsearch-7 - elastic4s中的scriptScoreQuery实现