python - 为什么使用 Numba 时这段代码不更快?
问题描述
为什么使用 numba jit 进行这种蒙特卡罗模拟没有更快的速度?删除@jit
使它运行得更快一些。但是,我认为这些循环是 numba 擅长的......
import numpy as np
from numba import jit
T = 1000
ALPHA = 0.11
BETA = 0.22
GAMMA = 0.33
@jit(nopython=True, fastmath=True)
def sim(n):
mu = np.array([0.0, 0.1]).reshape(2,1)
rho = 0.1
sigma = np.array([[1, rho*4],[rho*4, 4^2]])
A = np.linalg.cholesky(sigma)
out = np.empty((n, 2))
for i in range(n):
# (a)
u = np.random.randn(T)
# X = np.random.multivariate_normal(mu, sigma, T)
X = mu + A @ np.random.randn(2,T)
X = np.concatenate((np.ones((T, 1)), X.T), axis=1)
y = X @ np.array([ALPHA, BETA, GAMMA]) + u
# (b)
thetahat = np.linalg.solve(X.T @ X, X.T @ y)
Xf = X[:,:2].copy()
thetatilde = np.linalg.solve(Xf.T @ Xf, Xf.T @ y)
out[i,:] = (thetahat[1], thetatilde[1])
return out
n = 10**5
s = sim(n)
print(s)
解决方案
正如文档所述:
首先,回想一下,Numba 必须在执行函数的机器代码版本之前为给定的参数类型编译函数,这需要时间。但是,一旦编译完成,Numba 会针对所呈现的特定类型的参数缓存函数的机器代码版本。如果以相同的类型再次调用它,它可以重用缓存的版本,而不必再次编译。
测量性能时一个真正常见的错误是不考虑上述行为并使用简单的计时器对代码进行一次计时,该计时器包括在执行时间中编译函数所花费的时间。
简单地说:在第一次执行(在你的情况下,唯一的一次)中,numba 将其编译为机器代码。这需要时间。如果您再次运行它,那么您将看到不同之处。
例如:
import numpy as np
from numba import jit
import time
T = 1000
ALPHA = 0.11
BETA = 0.22
GAMMA = 0.33
@jit(nopython=True, fastmath=True)
def sim(n):
mu = np.array([0.0, 0.1]).reshape(2,1)
rho = 0.1
sigma = np.array([[1, rho*4],[rho*4, 4^2]])
A = np.linalg.cholesky(sigma)
out = np.empty((n, 2))
for i in range(n):
# (a)
u = np.random.randn(T)
# X = np.random.multivariate_normal(mu, sigma, T)
X = mu + A @ np.random.randn(2,T)
X = np.concatenate((np.ones((T, 1)), X.T), axis=1)
y = X @ np.array([ALPHA, BETA, GAMMA]) + u
# (b)
thetahat = np.linalg.solve(X.T @ X, X.T @ y)
Xf = X[:,:2].copy()
thetatilde = np.linalg.solve(Xf.T @ Xf, Xf.T @ y)
out[i,:] = (thetahat[1], thetatilde[1])
return out
n = 10**2
start = time.time()
sim(n)
end = time.time()
print("Elapsed (with compilation) = %s" % (end - start))
s = sim(n)
start = time.time()
sim(n)
end = time.time()
print("Elapsed (with compilation) = %s" % (end - start))
print(s)
在这里,我运行了两次模拟。第一次运行耗时 5 秒,但第二次运行仅耗时 0.01 秒。Numba 确实有所改善。
当你有一个你多次使用的功能时,Numba 很有用。对于单次执行,numba 没有用。
推荐阅读
- python - 如何从被调用函数“继续”调用函数中的循环?
- javascript - 如何在 GraphQl JS 中实现带有破折号的枚举
- python - 在 Python(matplotlib)中将水平线插入直方图中?
- mysql - 警告:#4038 参数 1 中 JSON 文本中的语法错误,函数 'st_geomfromgeojson' 在位置 29
- macos-big-sur - 无法在 macOS Big Sur 中执行 Java Robot 类
- python - 读取用户的击键
- java - JUnit - 如何以不同的参数以相同的方法测试重复指令?
- reactjs - 反应本机中的反应导航卸载问题
- java - 一个虚拟盒子上的几个类似的Java程序
- php - Laravel 批处理队列导致 MySQL 错误