TensorFlow: Removing nodes from a graph

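Question

I am trying to delete some nodes from a graph and save the result to a .pb file.

I can add only the nodes I need to a new graph, mod_graph_def, but the inputs of other nodes still contain references to the deleted nodes, and I cannot modify a node's inputs: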

def delete_ops_from_graph():
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # Problem: the inputs of other nodes still reference the deleted nodes
    for node in mod_graph_def.node:
        inp_names = []
        for inp in node.input:
            if 'Neg' in inp:
                pass
            else:
                inp_names.append(inp)

        node.input = inp_names # TypeError: Can't set composite field

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())

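Solution

node.input is a repeated protobuf field, so it cannot be assigned to directly. Clear it in place with del node.input[:] and then refill it with extend():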


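import tensorflow as tf

# tf.GraphDef is the TF 1.x name; in TF 2.x it is tf.compat.v1.GraphDef.
# input_model_filepath and output_model_filepath are assumed to be defined
# at module level, as in the question.


def delete_ops_from_graph():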
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Delete nodes
    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # Delete references to deleted nodes
    for node in mod_graph_def.node:
        inp_names = [inp for inp in node.input if 'Neg' not in inp]

        # A repeated field cannot be reassigned, so empty it in place
        # and then append the inputs that should remain.
        del node.input[:]
        node.input.extend(inp_names)

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())
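
The substring test 'Neg' in inp also matches any unrelated node whose name merely contains "Neg". The following is a minimal sketch of a stricter variant, not part of the original answer: it collects the exact names of the dropped nodes and compares each input against them. The helper names producer_name and drop_nodes are hypothetical. The sketch relies on GraphDef input strings having the form "name", "name:1" (output index) or "^name" (control dependency), and it runs on simple stand-in objects so that it works without TensorFlow installed.

```python
from types import SimpleNamespace


def producer_name(inp):
    """Return the name of the node that an input string refers to.

    Strips a leading "^" (control dependency) and a trailing
    ":<index>" (output index), if present.
    """
    return inp.lstrip("^").split(":")[0]


def drop_nodes(nodes, should_drop):
    """Return the nodes to keep, with inputs to dropped nodes removed.

    `nodes` is any iterable of objects with a `name` string and an
    `input` sequence that supports `del seq[:]` and `extend`.
    The kept nodes are modified in place.
    """
    nodes = list(nodes)
    dropped = {n.name for n in nodes if should_drop(n)}
    kept = [n for n in nodes if n.name not in dropped]
    for n in kept:
        remaining = [i for i in n.input if producer_name(i) not in dropped]
        # Same pattern as the answer above: clear in place, then refill.
        del n.input[:]
        n.input.extend(remaining)
    return kept


# Stand-in nodes for illustration only (not real NodeDef messages).
demo = [
    SimpleNamespace(name="x", input=[]),
    SimpleNamespace(name="Neg_1", input=["x"]),
    SimpleNamespace(name="out", input=["x", "Neg_1:0", "^Neg_1"]),
]
kept = drop_nodes(demo, lambda n: "Neg" in n.name)
print([(n.name, list(n.input)) for n in kept])
```

With a real graph, the same function could presumably be called as drop_nodes(graph_def.node, ...) and its result passed to mod_graph_def.node.extend(), since NodeDef.input supports the del and extend operations used in the answer. Note that simply removing an input changes how many inputs the consuming op receives, so the resulting graph may still need rewiring before it can be executed.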
