python - 对数伽玛函数的快速算法
问题描述
我正在尝试编写一个快速算法来计算对数伽玛函数。目前我的实现似乎很幼稚,只是迭代 1000 万次来计算 gamma 函数的日志(我也在使用 numba 来优化代码)。
import numpy as np
from numba import njit
EULER_MAS = 0.577215664901532 # euler mascheroni constant
HARMONC_10MIL = 16.695311365860007 # sum of 1/k from 1 to 10,000,000
@njit(fastmath=True)
def gammaln(z):
"""Compute log of gamma function for some real positive float z"""
out = -EULER_MAS*z - np.log(z) + z*HARMONC_10MIL
n = 10000000 # number of iters
for k in range(1,n+1,4):
# loop unrolling
v1 = np.log(1 + z/k)
v2 = np.log(1 + z/(k+1))
v3 = np.log(1 + z/(k+2))
v4 = np.log(1 + z/(k+3))
out -= v1 + v2 + v3 + v4
return out
我根据scipy.special.gammaln实现对我的代码进行了计时,而我的代码实际上慢了 100,000 倍。所以我在做一些非常错误或非常幼稚的事情(可能两者兼而有之)。尽管与 scipy 相比,我的答案至少在小数点后 4 位以内是正确的。
我试图阅读实现 scipy 的 gammaln 函数的 _ufunc 代码,但是我不明白 _gammaln 函数所写的 cython 代码。
有没有更快、更优化的方法可以计算对数伽玛函数?我如何理解 scipy 的实现,以便将其与我的结合起来?
解决方案
您的函数的运行时间将随着迭代次数线性扩展(直到一些恒定的开销)。所以减少迭代次数是加速算法的关键。虽然HARMONIC_10MIL
预先计算是一个聪明的想法,但当您截断系列时,它实际上会导致更差的准确性;仅计算系列的一部分结果会提供更高的准确性。
下面的代码是上面发布的代码的修改版本(尽管使用cython
代替numba
)。
from libc.math cimport log, log1p
cimport cython
cdef:
float EULER_MAS = 0.577215664901532 # euler mascheroni constant
@cython.cdivision(True)
def gammaln(float z, int n=1000):
"""Compute log of gamma function for some real positive float z"""
cdef:
float out = -EULER_MAS*z - log(z)
int k
float t
for k in range(1, n):
t = z / k
out += t - log1p(t)
return out
如下图所示,即使经过 100 次近似,它也能得到一个接近的近似值。
在 100 次迭代中,它的运行时间与以下数量级相同scipy.special.gammaln
:
%timeit special.gammaln(5)
# 932 ns ± 19 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit gammaln(5, 100)
# 1.25 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
剩下的问题当然是要使用多少次迭代。该函数log1p(t)
可以扩展为小的泰勒级数t
(这与大的极限有关k
)。尤其是,
log1p(t) = t - t ** 2 / 2 + ...
这样,对于大k
, sum 的参数变为
t - log1p(t) = t ** 2 / 2 + ...
因此,总和的自变量在二阶之前为零,如果足够小t
,则可以忽略不计。t
换言之,迭代次数应至少与 一样大z
,最好至少大一个数量级。
但是,如果可能的话,我会坚持使用scipy
经过良好测试的实现。
推荐阅读
- html - 打印div中生成的条码
- python - PyQt 和 gpiozero 回调
- firebase - 用于 UITableViewCell 内的 UILabel 的带有换行符“\n”的字符串不会中断
- ios - SwiftUI 中的约束有什么用?
- python-3.x - 如何针对 Intel Python 3.6.x(64 位)正确安装 PyQt5 模块?
- rest - 客户端与事件溯源的交互
- struct - 有序结构中的最后一个键
- css - 如何将覆盖在 div 上的图像居中 * 但是 * 使图像以底部而不是中心为中心?
- oracle - Oracle Dataguard TAF(透明应用程序故障转移)问题
- python - 如果我有一组坐标,我怎样才能得到国家,只有 Requests 库?