首页 > 解决方案 > 通过查询内联替换 AST 节点

问题描述

给定输入,例如:

query = ("class_name", "function_name", "arg_name")

如何用其他提供的内容替换找到的内容node

前阶段的解析示例:

class Foo(object):
    def f(self, g: str = "foo"): pass

解析后阶段的示例:

class Foo(object):
    def f(self, g: int = 5): pass

给定以下对假设函数的调用:

replace_ast_node(
    query=("Foo", "f", "g"),
    node=ast.parse("class Foo(object):\n    def f(self, g: str = 'foo'): pass"),
    # Use `AnnAssign` over `arg`; as `defaults` is higher in the `FunctionDef`
    replace_with=AnnAssign(
        annotation=Name(ctx=Load(), id="int"),
        simple=1,
        target=Name(ctx=Store(), id="g"),
        value=Constant(kind=None, value=5),
    ),
)

我已经拼凑了一个简单的解决方案,用于查找带有查询列表的节点("Foo", "f", "g"),它具有为任何可以引用的内容工作的额外好处def Foo(): def f(): def g():,以及来自to的解析器/发射器。但我无法弄清楚这个阶段;是否按顺序遍历?- 那么我是否应该不断遍历、附加当前名称并检查当前位置是否是完整的查询字符串?- 我觉得我缺少一些干净的解决方案……</p> argAnnAssignast.NodeTransformer

标签: pythonpython-3.xannotationsabstract-syntax-treedefault-arguments

解决方案


我决定把它分成两个问题。首先,让每个节点都知道它在宇宙中的位置:

def annotate_ancestry(node):
    """
    Look to your roots. Find the child; find the parent.
    Sets _location attribute to every child node.

    :param node: AST node. Will be annotated in-place.
    :type node: ```ast.AST```
    """
    node._location = [node.name] if hasattr(node, 'name') else []
    parent_location = []
    for _node in ast.walk(node):
        name = [_node.name] if hasattr(_node, 'name') else []
        for child in ast.iter_child_nodes(_node):
            if hasattr(child, 'name'):
                child._location = name + [child.name]
                parent_location = child._location
            elif isinstance(child, ast.arg):
                child._location = parent_location + [child.arg]

然后实现上述的一种方法ast.NodeTransformer

class RewriteAtQuery(ast.NodeTransformer):
    """
    Replace the node at query with given node

    :ivar search: Search query, e.g., ['class_name', 'method_name', 'arg_name']
    :ivar replacement_node: Node to replace this search
    """

    def __init__(self, search, replacement_node):
        """
        :param search: Search query
        :type search: ```List[str]```

        :param replacement_node: Node to replace this search
        :type replacement_node: ```ast.AST```
        """
        self.search = search
        self.replacement_node = replacement_node
        self.replaced = False

    def generic_visit(self, node):
        """
        Visit every node, replace once, and only if found

        :param node: AST node
        :type node: ```ast.AST```

        :returns: AST node, potentially edited
        :rtype: ```ast.AST```
        """
        if not self.replaced and hasattr(node, '_location') \
           and node._location == self.search:
            node = self.replacement_node
            self.replaced = True
        return ast.NodeTransformer.generic_visit(self, node)

你完成了=)


使用/测试:

parsed_ast = ast.parse(class_with_method_and_body_types_str)
annotate_ancestry(parsed_ast)
rewrite_at_query = RewriteAtQuery(
    search="C.method_name.dataset_name".split("."),
    replacement_node=arg(
        annotation=Name(ctx=Load(), id="int"),
        arg="dataset_name",
        type_comment=None,
    ),
).visit(parsed_ast)
self.assertTrue(rewrite_at_query.replaced, True)
# Additional test to compare AST produced with desired AST [see repo]

推荐阅读