python - 提高求幂性能
问题描述
我有一个由 MCMC 使用的相当简单的函数(参见下面的代码),这意味着它被调用了数百万次。据我所知,大部分时间都花在对数组求幂上,我想不出提高其性能的方法。它目前占用了 MCMC 总运行时间的约 15%,因此每一点改进都很重要。
这个功能可以做得更快吗?
import numpy as np
import time as t
def f1(abc, emax, rnd, arr):
"""This function should be as fast as possible"""
# Only the first 3 sub-arrays are modified
final_arr = []
for i, (a, b, c) in enumerate(abc):
# This is an 'error' obtained through this function
# It always uses arr[0], but the a,b,c values change
sigma = a * np.exp(b * arr[0]) + c
# Clip values at 'emax'
sigma[sigma > emax[i]] = emax[i]
# Add the errors to arr[i], 'rnd' is a random array
# of floats normally distributed with mean 0 and
# stdev 1.
final_arr.append(arr[i] + rnd[:len(arr[i])] * sigma)
return final_arr
# Some initial data with the proper shape. This data does
# not change with iterations
emax = [.05, .1, .17]
abc = [
[0.01068788, 0.13260967, -0.03015476],
[0.01068788, 0.13260967, -0.03015476],
[0.01068788, 0.13260967, -0.03015476]]
rnd = np.random.normal(0., 1., 1000000)
t1, t2 = 0., 0.
for _ in range(10000):
# Array of data with proper shape.
# This array changes with successive iterations.
arr1 = np.random.uniform(10., 30., (1, 1000))
arr2 = np.random.uniform(0., .3, (9, 1000))
arr = np.concatenate((arr1, arr2))
s = t.time()
f1(abc, emax, rnd, arr)
t1 += t.time() - s
print(t1)
解决方案
您可以使用广播来向量化操作并节省 30% 的计算时间。只需确保首先从值列表中创建 numpy 数组:
def f2(abc, eamx, rnd, arr):
sigma = abc[:, 0, None] * np.exp(abc[:, 1, None] * arr[0, :]) + abc[:, 2, None]
sigma = np.clip(sigma, a_min=None, a_max=emax[:, None])
final_arr = arr[:len(sigma), :]
return final_arr + rnd[:final_arr.shape[1]] * sigma
emax = np.asarray(emax)
abc = np.asarray(abc)
np.allclose(f1(abc, emax, rnd, arr), f2(abc, emax, rnd, arr))
# True
%timeit f1(abc, emax, rnd, arr)
78.3 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit f2(abc, emax, rnd, arr)
54.8 µs ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
推荐阅读
- php - Select2 从服务器获取值
- python - python子进程打印和存储标准输出
- git - 将我的分支合并到主分支的最安全方法?
- extjs - Sencha Cmd 版本与所需版本不兼容
- javascript - 有什么方法可以在 HighStocks 的测量注释工具栏中引入自定义功能?不是样式和 CSS,而是功能
- javascript - 应用按钮需要禁用但无法禁用 Toggler2 按钮的 onclick
- android - 基于构建类型的 gradle 模块解析
- database - 如何填充 ODS 表以从其他表中获取新生成的 ID
- python - 如何在python中没有收到任何套接字超时的数据后执行代码
- reactjs - Next.JS SSG 被 Redux Persist 破坏