首页 > 解决方案 > 我在 numba 中的 python 程序没有加速

问题描述

我为计算磁化而编写的程序需要更多的计算时间。所以我切换到numba。但我看不到任何速度增加。任何人都可以帮助我。我正在尝试在 24 核处理器中运行此代码。

import time  
import datetime
import numpy as np from math 
import pi
import numba from numba 
import jit,njit,double,vectorize,float64,int64
import time 

 #%% parameters for the calculations
    mu0 = 4e-7 * pi
        h_planck=6.58212e-4# mev*ns
        mub=5.78e-2#meV/T
        g=2
        s=2
        T=2.0 #K
       dt=0.5e-5
       Kb=8.6e-2
       Kjt=1.5
       gamma =(g*mub)/h_planck #1/(T*ns) 
       alpha = 1
       mus=mub*g*s
eA=np.array([-np.sqrt(2.0/3.0),0.0,-np.sqrt(1.0/3.0)])
eB=np.array([-np.sqrt(1.0/6.0),-np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
eC=np.array([-np.sqrt(1.0/6.0),np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
@njit
def dot(S1,eA,eB,eC):
    result1=0.0
    result2=0.0
    result3=0.0
    for i in range(3):
        result1 += S1[i]*eA[i]
        result2 += S1[i]*eB[i]
        result3 += S1[i]*eC[i]

    return result1,result2,result3
@njit
def jahnteller1(S1):
    global Kjt
    M,N,O=dot(S1,eA,eB,eC)
    P,Q,R=M**5,N**5,O**5
    X=3.0*Kjt*((eA*P+eB*Q+eC*R))
    return X/mus
@njit
def thermal1():
    mu, sigma = 0, 1 # mean and standard deviation
    G = np.random.normal(mu, sigma, 3)
    Hth1=G*np.sqrt((2*alpha*Kb*T)/(gamma*mus*dt))
    return Hth1
#%% calculation of effective field
@njit
def h_eff(B,S1,eH):
    Heff1 = eH*B+jahnteller1(S1)+thermal1()
    return  Heff1
#%% evaluating cross products
@njit
def cross1(S1,heff1):
    result1=np.zeros(3)
    a1, a2, a3 = S1[0], S1[1], S1[2]
    b1, b2, b3 = heff1[0], heff1[1],heff1[2]
    result1[0] = a2 * b3 - a3 * b2
    result1[1] = a3 * b1 - a1 * b3
    result1[2] = a1 * b2 - a2 * b1
    return result1
@njit
def cross2(S1,X):
    result2=np.zeros(3)
    a1, a2, a3 = S1[0],S1[1],S1[2]
    c1, c2, c3 = X[0],X[1],X[2]
    result2[0] = a2 * c3 - a3 * c2
    result2[1] = a3 * c1 - a1 * c3
    result2[2] = a1 * c2 - a2 * c1
    return result2
#%% Main function to calculate the Spin S1 by calculating the effective field
 @njit
def llg(S1,dt, B,eH):
    global gamma,alpha
    N_init = int(5)
    for i in range(N_init):
        heff1 = h_eff(B,S1,eH)
        X=cross1(S1,heff1)
        Y=cross2(S1,X)
        dS1dt = - gamma/(1+alpha**2) * X \
           - alpha*gamma/(1+alpha**2) * Y
        S1 += dt * dS1dt
        normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
        S1 = S1/normS1
    Savg=np.array([0.0,0.0,0.0])
    Navg=N_init*10
    for i in range(Navg):
        heff1 = h_eff(B,S1,eH)
        X=cross1(S1,heff1)
        Y=cross2(S1,X)
        dS1dt = - gamma/(1+alpha**2) * X \
           - alpha*gamma/(1+alpha**2) * Y
        S1 += dt * dS1dt
        normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
        S1 = S1/normS1
        Savg=Savg+S1
    Savg=Savg/Navg
    return Savg  
#%% calculating dot product
@njit
def dott(S1,K):
    result=0.0
    for i in range(3):
        result += S1[i]*K[i]
    return result


 #%% initialising magn
        magn=np.zeros([25,3]) 
        Th=[]
        Ph=[]
        B=5.0
        theta=np.linspace(0.0,np.pi,5)
        phi=np.linspace(0.0,2*np.pi,5)
    for i in range(len(phi)):
        for j in range(len(theta)):
            M,N=phi[i],theta[j]
            Th.append(N)
            Ph.append(M)

#%% calling the main fuction
for i in range(25):
    magn[i][0]=Ph[i]
    magn[i][1]=Th[i]
    eH=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
    normH = np.sqrt(eH[0]*eH[0]+eH[1]*eH[1]+eH[2]*eH[2])
    eH=eH/normH
    S1=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
    S1=llg(S1,dt,B,eH)
    K=eH*B
    Z=dott(S1,K)
    E=-Z*g*mub*s
    magn[i][2]=E

#%% printing magn
print(magn)
%timeit magn

标签: pythonnumpynumba

解决方案


需要注意的几点:

  1. 您似乎没有为整个操作计时。我不确定最后一个%timeit表达式给了你什么
  2. 在我的机器上,运行您给出的代码大约需要 2.5 秒
  3. 请注意,第一次调用 Numba 函数非常慢,因为编译器会将代码转换为 llvm 代码。如果将 @njit 更改为 @njit(cache=True) 则结果将被缓存,并且以后的运行不会导致编译(直到您更改函数)。当我在我的机器上执行此操作时,第一次运行仍需要 2.5 秒,但第二次运行在 0.12 秒内完成。
  4. 这些都无法在纯 python 中运行这个函数,只需要 0.06 秒。

为什么?

在您的代码中出现这种情况的最大原因似乎是您在循环中从 Python 调用了许多小函数。调用 numba 函数会产生开销(我认为这比调用普通 python 函数的开销更糟,因为需要进行类型检查)。因此,如果您的 jitted 函数很简单,那么使用它们的好处就可以忽略不计(或者更糟糕的是,您会因此受到惩罚)。如果您可以更改代码,使整个逻辑(即主循环)也在 numba 函数中,它可能会比纯 python 更快。


推荐阅读