python - 如何以复数计算 digamma 函数以便在 Tensorflow 中使用此函数(接受输入作为张量)?
问题描述
N= [-7.12843079e+02, -1.39668296e+02, -6.01626070e+01, -3.51688015e+01]
jax.scipy.special.digamma(N)
TypeError: digamma does not accept dtype complex64. Accepted dtypes are subtypes of floating.
我正在尝试使用 jax.scipy.special.digamma 以复数计算 digamma,但是,即使此包的文档说它可能很复杂,它仍然给我这个错误,这是文档所说的:
参数: z (array_like) – 实数或复数参数。
知道如何解决这个问题吗?或者是否有其他方法,例如其他库或其他包,允许我使用复数来计算 digamma 函数!?
解决方案
我有一个类似的问题。这就是我解决它的方法。漫长的道路。我知道有人会给出一个简洁的方法。
从定义psi 函数
我从这里使用了 Gamma 函数。确保输出在 JAX 中,否则您将无法使用 GRAD;
import jax.numpy as jnp
from jax import grad
def gamma_func_numeric(z):
g = 7
z -= 1
x = lanczos_coef[0]
for i in range(1, g+2):
x += lanczos_coef[i]/(z+i)
t = z + g + 0.5
return jnp.sqrt(2*jnp.pi)*jnp.power(t,(z+0.5))*jnp.exp(-t)*x
eta is a complex number
def psi_numeric(eta):
gamma_prime = grad(gamma_func_numeric, holomorphic=True)(eta)
gamma = gamma_func_numeric(eta)
return gamma_prime/gamma
让我们比较一下我们的结果:
scipy.special.psi(1.+2j)=(0.7145915153739777+1.3208072826422304j)
psi_numeric(1.+2j) = DeviceArray(0.7145916+1.3208078j, dtype=complex64)
推荐阅读
- python - 即使输入没有下降趋势,我的自动 ARIMA 也会显示下降趋势
- c# - EFMigrationsHistory 表为空
- python - Heroku Postgres 数据库在我的 Django 项目中不起作用
- javascript - 如何使用 svelte-routing 在当前路由器中导航?
- css - Bootstrap 5个浮动标签重叠
- flutter - 在 Flutter 中将数据从一个小部件传输到整个项目
- python - 返回一个数据框作为结果 Flask
- tensorflow - ValueError: 无法解析模型:pybind11::init(): factory function returned nullptr
- mongodb - Nestjs与Mongodb本机驱动程序注入连接问题
- api-doc - 使用 apidocjs 生成 api doc 时是否可以覆盖值?