python - 我在 numba 中的 python 程序没有加速
问题描述
我为计算磁化而编写的程序需要更多的计算时间。所以我切换到numba
。但我看不到任何速度增加。任何人都可以帮助我。我正在尝试在 24 核处理器中运行此代码。
import time
import datetime
import numpy as np from math
import pi
import numba from numba
import jit,njit,double,vectorize,float64,int64
import time
#%% parameters for the calculations
mu0 = 4e-7 * pi
h_planck=6.58212e-4# mev*ns
mub=5.78e-2#meV/T
g=2
s=2
T=2.0 #K
dt=0.5e-5
Kb=8.6e-2
Kjt=1.5
gamma =(g*mub)/h_planck #1/(T*ns)
alpha = 1
mus=mub*g*s
eA=np.array([-np.sqrt(2.0/3.0),0.0,-np.sqrt(1.0/3.0)])
eB=np.array([-np.sqrt(1.0/6.0),-np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
eC=np.array([-np.sqrt(1.0/6.0),np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
@njit
def dot(S1,eA,eB,eC):
result1=0.0
result2=0.0
result3=0.0
for i in range(3):
result1 += S1[i]*eA[i]
result2 += S1[i]*eB[i]
result3 += S1[i]*eC[i]
return result1,result2,result3
@njit
def jahnteller1(S1):
global Kjt
M,N,O=dot(S1,eA,eB,eC)
P,Q,R=M**5,N**5,O**5
X=3.0*Kjt*((eA*P+eB*Q+eC*R))
return X/mus
@njit
def thermal1():
mu, sigma = 0, 1 # mean and standard deviation
G = np.random.normal(mu, sigma, 3)
Hth1=G*np.sqrt((2*alpha*Kb*T)/(gamma*mus*dt))
return Hth1
#%% calculation of effective field
@njit
def h_eff(B,S1,eH):
Heff1 = eH*B+jahnteller1(S1)+thermal1()
return Heff1
#%% evaluating cross products
@njit
def cross1(S1,heff1):
result1=np.zeros(3)
a1, a2, a3 = S1[0], S1[1], S1[2]
b1, b2, b3 = heff1[0], heff1[1],heff1[2]
result1[0] = a2 * b3 - a3 * b2
result1[1] = a3 * b1 - a1 * b3
result1[2] = a1 * b2 - a2 * b1
return result1
@njit
def cross2(S1,X):
result2=np.zeros(3)
a1, a2, a3 = S1[0],S1[1],S1[2]
c1, c2, c3 = X[0],X[1],X[2]
result2[0] = a2 * c3 - a3 * c2
result2[1] = a3 * c1 - a1 * c3
result2[2] = a1 * c2 - a2 * c1
return result2
#%% Main function to calculate the Spin S1 by calculating the effective field
@njit
def llg(S1,dt, B,eH):
global gamma,alpha
N_init = int(5)
for i in range(N_init):
heff1 = h_eff(B,S1,eH)
X=cross1(S1,heff1)
Y=cross2(S1,X)
dS1dt = - gamma/(1+alpha**2) * X \
- alpha*gamma/(1+alpha**2) * Y
S1 += dt * dS1dt
normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
S1 = S1/normS1
Savg=np.array([0.0,0.0,0.0])
Navg=N_init*10
for i in range(Navg):
heff1 = h_eff(B,S1,eH)
X=cross1(S1,heff1)
Y=cross2(S1,X)
dS1dt = - gamma/(1+alpha**2) * X \
- alpha*gamma/(1+alpha**2) * Y
S1 += dt * dS1dt
normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
S1 = S1/normS1
Savg=Savg+S1
Savg=Savg/Navg
return Savg
#%% calculating dot product
@njit
def dott(S1,K):
result=0.0
for i in range(3):
result += S1[i]*K[i]
return result
#%% initialising magn
magn=np.zeros([25,3])
Th=[]
Ph=[]
B=5.0
theta=np.linspace(0.0,np.pi,5)
phi=np.linspace(0.0,2*np.pi,5)
for i in range(len(phi)):
for j in range(len(theta)):
M,N=phi[i],theta[j]
Th.append(N)
Ph.append(M)
#%% calling the main fuction
for i in range(25):
magn[i][0]=Ph[i]
magn[i][1]=Th[i]
eH=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
normH = np.sqrt(eH[0]*eH[0]+eH[1]*eH[1]+eH[2]*eH[2])
eH=eH/normH
S1=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
S1=llg(S1,dt,B,eH)
K=eH*B
Z=dott(S1,K)
E=-Z*g*mub*s
magn[i][2]=E
#%% printing magn
print(magn)
%timeit magn
解决方案
需要注意的几点:
- 您似乎没有为整个操作计时。我不确定最后一个
%timeit
表达式给了你什么 - 在我的机器上,运行您给出的代码大约需要 2.5 秒
- 请注意,第一次调用 Numba 函数非常慢,因为编译器会将代码转换为 llvm 代码。如果将 @njit 更改为 @njit(cache=True) 则结果将被缓存,并且以后的运行不会导致编译(直到您更改函数)。当我在我的机器上执行此操作时,第一次运行仍需要 2.5 秒,但第二次运行在 0.12 秒内完成。
- 这些都无法在纯 python 中运行这个函数,只需要 0.06 秒。
为什么?
在您的代码中出现这种情况的最大原因似乎是您在循环中从 Python 调用了许多小函数。调用 numba 函数会产生开销(我认为这比调用普通 python 函数的开销更糟,因为需要进行类型检查)。因此,如果您的 jitted 函数很简单,那么使用它们的好处就可以忽略不计(或者更糟糕的是,您会因此受到惩罚)。如果您可以更改代码,使整个逻辑(即主循环)也在 numba 函数中,它可能会比纯 python 更快。
推荐阅读
- javascript - 我想将 credential.access 令牌存储在一个变量中,我尝试了以下但没有成功
- c++ - C++ regex_search 与 perl 中的正则表达式匹配
- excel - 文件夹中的 Excel 电源查询
- google-maps-api-3 - Hybris 在地理定位期间抛出异常
- python - 增加数字 1-5 的组合
- python - 正则表达式(Python) - 匹配所需单词之前的所有内容
- java - 运行批处理时无法创建工件
- r - 如何在新引擎中包含绘图
- java - 创建新 MPart 时如何获取在 Application.e4xmi 中注册的 bundleclass - Eclipse RCP e4
- javascript - ajax成功功能后如何刷新/重新加载数据表