首页 > 解决方案 > 仅适用于实数的 Lambdifying Sympy 函数

问题描述

为了性能,我正在尝试将三次方程的输出转换为 sympy 之外的函数。

import sympy as sp

u, x, y = sp.symbols('u x y')

eq = - y - sp.Integral(x**2, (x, 1, u))

solved = sp.solveset(eq.doit(), u, domain=sp.S.Reals)

lam = sp.lambdify(y, sp.solveset(eq.doit(), u, domain=sp.S.Reals))

将任何数字放入lamthrows NameError: name 'Intersection' is not defined中,即使在Intersection专门导入时也是如此。

要求似乎有点多,考虑到开销,我不希望这能起作用,但是有没有办法让这个输出y作为输入,作为我不需要依赖 sympy 的外部函数做计算?我尝试输入一个值列表作为Matrix对象,但这使输出更加复杂,并且似乎在任何地方都没有正确的答案。

标签: pythonsympy

解决方案


运行您的代码并回答我的问题:

In [21]: 
    ...: u, x, y = symbols('u x y')
    ...: 
    ...: eq = - y - Integral(x**2, (x, 1, u))
    ...: 
    ...: solved = solveset(eq.doit(), u, domain=S.Reals)
    ...: 
    ...: lam = lambdify(y, solveset(eq.doit(), u, domain=S.Reals))

In [22]: solved
Out[22]: 
    ⎧3 ___________              ⎫
ℝ ∩ ⎨╲╱ │3⋅y - 1│ ⋅sign(1 - 3⋅y)⎬
    ⎩                           ⎭

In [23]: print(lam.__doc__)
Created with lambdify. Signature:

func(y)

Expression:

Intersection(FiniteSet(Abs(3*y - 1)**(1/3)*sign(1 - 3*y)), Reals)

Source code:

def _lambdifygenerated(y):
    return (  # Not supported in Python with SciPy:
  # FiniteSet
Intersection(FiniteSet(Abs(3*y - 1)**(1/3)*sign(1 - 3*y)), Reals))

numpy/scipy没有任何实现Intersection或的功能FiniteSetsympy无法翻译这些功能。

===

随着你的变化:

In [25]: ed=eq.doit()

In [26]: ed
Out[26]: 
   3        
  u        1
- ── - y + ─
  3        3

In [27]: 

In [27]: lam = lambdify(y,ed)

In [28]: print(lam.__doc__)
Created with lambdify. Signature:

func(y)

Expression:

-u**3/3 - y + 1/3

Source code:

def _lambdifygenerated(y):
    return (-1/3*u**3 - y + 1/3)


Imported modules:



In [29]: vals = lam(np.arange(10))

In [30]: vals
Out[30]: 
array([0.333333333333333 - 0.333333333333333*u**3,
       -0.333333333333333*u**3 - 0.666666666666667,
       -0.333333333333333*u**3 - 1.66666666666667,
       -0.333333333333333*u**3 - 2.66666666666667,
       -0.333333333333333*u**3 - 3.66666666666667,
       -0.333333333333333*u**3 - 4.66666666666667,
       -0.333333333333333*u**3 - 5.66666666666667,
       -0.333333333333333*u**3 - 6.66666666666667,
       -0.333333333333333*u**3 - 7.66666666666667,
       -0.333333333333333*u**3 - 8.66666666666667], dtype=object)

推荐阅读