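python - Cython vs numba performance
Problem description
Hi, I'm currently working on a Python module for thermodynamic fluid phase equilibria. For this I need to program activity coefficient models (such as NRTL), which involve several summations. To improve the module's performance, I tried using numba to jit the function: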
import numpy as np
from numba import jit

@jit(cache=True)
def NRTL(X, T, g, alpha, g1):
    '''
    NRTL activity coefficient model.
    input
    X: array like, vector of molar fractions
    T: float, absolute temperature in K.
    g: array like, matrix of energy interactions in K.
    g1: array_like, matrix of energy interactions in K^2
    alpha: array_like, matrix of aleatory (non-randomness) factors.
    tau = ((g + g1/T)/T)
    output
    lngama: array_like, natural logarithm of activity coefficient
    '''
    # Note: these two lines compute g/T + g1, which differs from the
    # docstring formula above whenever g1 is non-zero.
    tau = g + g1*T
    tau /= T
    nc = len(X)
    G = np.exp(-alpha*tau)
    lngama = np.zeros_like(X)
    for i in range(nc):
        SumC = SumD = SumE = 0.0
        for j in range(nc):
            A = X[j]*G[i, j]
            SumA = SumB = 0.0
            for k in range(nc):
                SumA += X[k]*G[k, j]
                SumB += X[k]*G[k, j]*tau[k, j]
            SumC += A/SumA*(tau[i, j] - SumB/SumA)
            SumD += X[j]*G[j, i]*tau[j, i]
            SumE += X[j]*G[j, i]
        lngama[i] = SumD/SumE + SumC
    return lngama
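Written out, the loops above evaluate the sums below. This is read directly off the code, with τ taken as the code computes it (g/T + g1) rather than as the docstring states it. The shorthand symbols S_j and B_j are introduced here for illustration and do not appear in the original.

```latex
\[
\begin{aligned}
\tau_{ij} &= \frac{g_{ij}}{T} + g^{(1)}_{ij}, \qquad
G_{ij} = \exp\left(-\alpha_{ij}\,\tau_{ij}\right) \\
S_j &= \sum_k x_k\,G_{kj}, \qquad
B_j = \sum_k x_k\,G_{kj}\,\tau_{kj} \\
\ln\gamma_i &= \frac{B_i}{S_i}
  + \sum_j \frac{x_j\,G_{ij}}{S_j}\left(\tau_{ij} - \frac{B_j}{S_j}\right)
\end{aligned}
\]
```

Here S_j and B_j correspond to `SumA` and `SumB`, the sum over j to `SumC`, and the ratio B_i/S_i to `SumD/SumE`.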
I'm now trying other options, such as Cython, but I'm not getting the performance of numba's jit.
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef double[:] nrtlaux(double [:] X, double [:,::1] G, double [:,::1] tau, int nc):
    cdef int i, j, k
    cdef double A, SumA, SumB, SumC, SumD, SumE, aux1, aux2
    cdef double [:] lngama = np.zeros(nc)

    for i in range(nc):
        SumC = SumD = SumE = 0.
        for j in range(nc):
            A = X[j]*G[i,j]
            SumA = SumB = 0.
            for k in range(nc):
                aux1 = X[k]*G[k,j]
                SumA += aux1
                SumB += aux1*tau[k,j]
            SumC += A/SumA*(tau[i,j]-SumB/SumA)
            aux2 = X[j]*G[j,i]
            SumD += aux2*tau[j,i]
            SumE += aux2
        lngama[i] = SumD/SumE+SumC
    return lngama

def NRTL(np.ndarray[double, ndim=1] X, double T, np.ndarray[double, ndim=2] g,
         np.ndarray[double, ndim=2] alpha, np.ndarray[double, ndim=2] g1):
    cdef int nc = len(X)
    cdef:
        double[:,::1] tau = (g/T + g1)
        double[:,::1] G = np.exp( -alpha * tau )
    lngama = nrtlaux(X, G, tau, nc)
    return np.asarray(lngama)
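Before comparing timings, it can help to confirm that both versions return the same values. The following is a minimal, dependency-free sketch for that purpose, not a replacement for either implementation. The name `nrtl_reference` and the use of nested lists are assumptions made for illustration. It also relies on one observation from the loops above: the inner sums over `k` do not depend on `i`, so they can be computed once per column, and `SumD/SumE` for component `i` is the ratio of those same column sums.

```python
import math

def nrtl_reference(x, T, g, alpha, g1):
    """Pure-Python NRTL reference (hypothetical helper); returns ln(gamma) as a list.

    Uses tau = g/T + g1 and G = exp(-alpha*tau), as both versions above compute them.
    """
    nc = len(x)
    tau = [[g[i][j] / T + g1[i][j] for j in range(nc)] for i in range(nc)]
    G = [[math.exp(-alpha[i][j] * tau[i][j]) for j in range(nc)] for i in range(nc)]

    # Column sums (SumA and SumB in the original loops), computed once per j.
    sum_a = [sum(x[k] * G[k][j] for k in range(nc)) for j in range(nc)]
    sum_b = [sum(x[k] * G[k][j] * tau[k][j] for k in range(nc)) for j in range(nc)]

    ln_gamma = []
    for i in range(nc):
        # SumC in the original loops.
        sum_c = sum(
            x[j] * G[i][j] / sum_a[j] * (tau[i][j] - sum_b[j] / sum_a[j])
            for j in range(nc)
        )
        # SumD/SumE for component i equals sum_b[i] / sum_a[i].
        ln_gamma.append(sum_b[i] / sum_a[i] + sum_c)
    return ln_gamma

# Parameters from the question, written as plain lists.
x = [0.5, 0.4, 0.1]
g = [[0.0, 35.00002657, 463.719316],
     [341.00001923, 0.0, 96.02154497],
     [1194.42262, 534.77089478, 0.0]]
alpha = [[0.0, 0.3456916919878884, 0.242020522],
         [0.3456916919878884, 0.0, 0.54],
         [0.242020522, 0.54, 0.0]]
g1 = [[0.0] * 3 for _ in range(3)]
print(nrtl_reference(x, 350.0, g, alpha, g1))
```

The printed list can be compared element by element against the output of the jitted and Cython functions, for example with `np.allclose`.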
I evaluated the functions with the following parameters:
X = np.array([0.5,0.4,0.1])
g = np.array([[0,35.00002657,463.719316],[341.00001923,0,96.02154497],[1194.42262, 534.77089478,0]])
alpha = np.array([[0,0.3456916919878884,0.242020522],[0.3456916919878884,0,0.54 ],[0.242020522,0.54 ,0]])
g1 = np.zeros_like(g)
T = 350.
I got the following results:
%timeit NRTL(X,T,g,alpha, g1) #cython
13.9 µs ± 489 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit nrtltp(X,T,g,alpha, g1) #numba
1.82 µs ± 35 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
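The `%timeit` lines above are IPython magics. Outside IPython, a comparable measurement can be sketched with the standard `timeit` module. The helper name `time_per_call` and the `dummy` workload are assumptions made so the example runs on its own. Unlike `%timeit`, which reports a mean and standard deviation, this sketch reports the best of several repeats, and the `lambda` wrapper adds a small constant overhead to every call.

```python
import timeit

def time_per_call(func, *args, number=100_000, repeat=7):
    """Return the best observed time per call of func(*args), in microseconds."""
    timer = timeit.Timer(lambda: func(*args))
    totals = timer.repeat(repeat=repeat, number=number)  # seconds per batch of `number` calls
    return min(totals) / number * 1e6

# Stand-in workload (hypothetical); replace with NRTL and (X, T, g, alpha, g1).
def dummy(a, b):
    return a * b + 1.0

print(f"{time_per_call(dummy, 2.0, 3.0, number=10_000, repeat=3):.3f} us per call")
```

When timing a jitted function this way, call it once beforehand so that compilation is not included in the measurement.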
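I'm a bit surprised by how well the jitted function performs. I'm also a beginner with Cython, so I'd appreciate any suggestions for improving its performance.
Solution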