首页 > 解决方案 > 为什么使用 Numba 时这段代码不更快?

问题描述

为什么使用 numba jit 进行这种蒙特卡罗模拟没有更快的速度?删除@jit使它运行得更快一些。但是,我认为这些循环是 numba 擅长的......

import numpy as np
from numba import jit

T = 1000
ALPHA = 0.11
BETA = 0.22
GAMMA = 0.33

@jit(nopython=True, fastmath=True)
def sim(n):
  mu = np.array([0.0, 0.1]).reshape(2,1)
  rho = 0.1
  sigma = np.array([[1, rho*4],[rho*4, 4^2]])
  A = np.linalg.cholesky(sigma)

  out = np.empty((n, 2))
  for i in range(n):
    # (a)
    u = np.random.randn(T)

    # X = np.random.multivariate_normal(mu, sigma, T)
    X = mu + A @ np.random.randn(2,T)
    X = np.concatenate((np.ones((T, 1)), X.T), axis=1)

    y = X @ np.array([ALPHA, BETA, GAMMA]) + u

    # (b)
    thetahat = np.linalg.solve(X.T @ X, X.T @ y)

    Xf = X[:,:2].copy()
    thetatilde = np.linalg.solve(Xf.T @ Xf, Xf.T @ y)

    out[i,:] = (thetahat[1], thetatilde[1])

  return out

n = 10**5
s = sim(n)
print(s)

标签: pythonnumba

解决方案


正如文档所述:

首先,回想一下,Numba 必须在执行函数的机器代码版本之前为给定的参数类型编译函数,这需要时间。但是,一旦编译完成,Numba 会针对所呈现的特定类型的参数缓存函数的机器代码版本。如果以相同的类型再次调用它,它可以重用缓存的版本,而不必再次编译。

测量性能时一个真正常见的错误是不考虑上述行为并使用简单的计时器对代码进行一次计时,该计时器包括在执行时间中编译函数所花费的时间。

简单地说:在第一次执行(在你的情况下,唯一的一次)中,numba 将其编译为机器代码。这需要时间。如果您再次运行它,那么您将看到不同之处。

例如:

import numpy as np
from numba import jit
import time

T = 1000
ALPHA = 0.11
BETA = 0.22
GAMMA = 0.33

@jit(nopython=True, fastmath=True)
def sim(n):
  mu = np.array([0.0, 0.1]).reshape(2,1)
  rho = 0.1
  sigma = np.array([[1, rho*4],[rho*4, 4^2]])
  A = np.linalg.cholesky(sigma)

  out = np.empty((n, 2))
  for i in range(n):
    # (a)
    u = np.random.randn(T)

    # X = np.random.multivariate_normal(mu, sigma, T)
    X = mu + A @ np.random.randn(2,T)
    X = np.concatenate((np.ones((T, 1)), X.T), axis=1)

    y = X @ np.array([ALPHA, BETA, GAMMA]) + u

    # (b)
    thetahat = np.linalg.solve(X.T @ X, X.T @ y)

    Xf = X[:,:2].copy()
    thetatilde = np.linalg.solve(Xf.T @ Xf, Xf.T @ y)

    out[i,:] = (thetahat[1], thetatilde[1])

  return out

n = 10**2
start = time.time()
sim(n)
end = time.time()
print("Elapsed (with compilation) = %s" % (end - start))

s = sim(n)
start = time.time()
sim(n)
end = time.time()
print("Elapsed (with compilation) = %s" % (end - start))
print(s)

在这里,我运行了两次模拟。第一次运行耗时 5 秒,但第二次运行仅耗时 0.01 秒。Numba 确实有所改善。

当你有一个你多次使用的功能时,Numba 很有用。对于单次执行,numba 没有用。


推荐阅读