首页 > 解决方案 > 替换数学公式中的点积运算符

问题描述

如何将点积(@)的pythons内部符号替换为numpy的numpy.dot?例如,公式 m.a@x + m.b@y应转换为np.dot(m.a, x) + np.dot(m.b, y).

我最初的想法是使用正则表达式在 @ (m.am.b上面的示例中)之前和之后查找文本,然后将它们放入 dot 函数中。这是我想象的使用正则表达式的方式:

# m.a, m.b, x and y are vectors of some equal size
formula = "m.a@x + m.b@y"
before_dots, after_dots = some_regex_function(formula)
result = eval(f"np.dot({before_dot[0]},{after_dot[0]}) + np.dot({before_dot[1]},{after_dot[1]})")

标签: pythonregexreplace

解决方案


使用astandastor模块,您可以解析代码并将运算符为矩阵乘法的所有二元运算节点替换为np.dot调用。

import ast
import astor


class ReplaceNpDot(astor.TreeWalk):
    def post_BinOp(self):
        node = self.cur_node
        if isinstance(node.op, ast.MatMult):
            np = ast.Name(id="np", ctx=ast.Load())
            np_dot = ast.Attribute(np, 'dot', ctx=ast.Load())
            self.replace(ast.Call(
                np_dot,
                args=[node.left, node.right],
                keywords=[],
                startargs=[]
            ))
        else:
            return node


# define m so it works...
# ...

# replace @ with np.dot
tree = ast.parse("m.a@x + m.b@y", mode='eval')
walker = ReplaceNpDot()
walker.walk(tree)

# print source code
print(astor.to_source(tree))

# run code
code = compile(ast.fix_missing_locations(tree), '<string>', 'eval')
exec(code)

推荐阅读