python - 从 sympy 生成优化的倍频程代码
问题描述
我有一些巨大的矩阵要导出,其中只包含 sin(q)、cos(q) 和这些的 sum/muls。Sympy 可以计算并将其导出为八度 - 这太棒了!但是,由于这些是大型矩阵,我需要某种cse
甚至更好的专用优化。
我用 cse 找到了这个很棒的 C 代码教程。所以我尝试自己移植它,但我在打印机类的一些细节上失败了。我认为这是一个无限递归,导致RecursionError: maximum recursion depth exceeded
.
我的问题是:有没有一个例子 sympy-octave 代码生成和优化如何结合在一起?或者有人可以帮我让附加的mwe运行吗?
import sympy as sp
t = sp.symbols('t')
from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):
def _print_ImmutableDenseMatrix(self, expr):
sub_exprs, simplified = sp.cse(expr)
lines = []
for var, sub_expr in sub_exprs:
lines.append( self._print(Assignment(var, sub_expr)))
M = sp.MatrixSymbol('M', *expr.shape)
return '\n'.join(lines) + '\n' + self._print(Assignment(M, expr))
tmp = sp.sin(t)+sp.sin(t)**2
tmp = sp.ImmutableDenseMatrix((1,1,tmp))
se, ex = sp.cse(tmp)
print((ex,se))
print('\n')
#tmp = sp.Matrix([2*sp.sin(t),sp.sin(t)])
p = matlabMatrixPrinter()
print(p.doprint(tmp))
编辑:我现在发现,return 语句中的第二个赋值也运行函数 _print_ImmutableDenseMatrix,所以这最终是一个递归。我不知道为什么在本教程中这对 C 代码没有问题,但在这里它递归运行。似乎只有简化表达式本身无法调用 self._print 函数的问题。也许有人对这些打印机有所了解,以及应该如何打印矩阵和这个单一的作业?!
解决方案
经过大量的实验,我觉得我仍然只了解 codePrinter 的有意工作流程背后的一些意图。然而,我写了一个子类,它完全符合我的预期(小心,因为这可能不适用于除矩阵之外的任何东西!)。
也许这对某人有用!对我来说,它绝对验证了 sympy 作为一个工作工具,因为否则成千上万的sin
评估将是绝对不可行的代码。
我仍然会对某人的评论和想法非常感兴趣,谁知道这些功能应该如何实现!
import sympy as sp
t = sp.symbols('t')
from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):
def print2(self,expr_list,names=None):
sub_exprs, simplified = sp.cse(expr_list)
lines = []
for var, sub_expr in sub_exprs:
lines.append(self._print(Assignment(var, sub_expr)))
lines.append('')
for k,expr in enumerate(simplified):
if names:
M = sp.MatrixSymbol(names[k],*expr.shape)
else:
M = sp.MatrixSymbol('M{k}'.format(k=k), *expr.shape)
lines.append(self._print(Assignment(M,expr)))
result = ''
return '\n'.join(lines)
tmp = sp.Matrix([sp.sin(t)+sp.sin(t)**2 ])
tmp2 = sp.Matrix([sp.sin(t),sp.cos(t),2*sp.sin(t),sp.cos(t)**2])
p = matlabMatrixPrinter()
#print(p.print2([tmp,tmp2]))
print(p.print2([tmp,tmp2],['scalar_matrix','matrix']));
这给出了预期的输出:
x0 = sin(t);
x1 = cos(t);
scalar_matrix = x0.^2 + x0;
matrix = [x0; x1; 2*x0; x1.^2];
如上所述:使用风险自负:)
推荐阅读
- opengl-es - 如何在 OpenGLES 中为多个对象上的多个纹理设置片段着色器?
- firebase - 在 Flutter 中实现 Firebase 分析
- javascript - vue生产模式下静态资产url中的“localhost”
- elasticsearch - 在图形面板中显示 X 轴时间自定义范围的数据
- prestashop-1.7 - 解释 Prestashop 上的模块联系表
- java - 从 HashMap 返回 ArrayList
- angular - 角度 7 中的 require("http2")
- reactjs - React 如何让子组件拥有自己的状态
- c++ - 用 Qt::black 初始化 QColorDialog 总是返回黑色
- angular - Angular 获取 url 的最后一部分