Improving exponentiation performance

Problem description

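I have a fairly simple function (see the code below) that is used by an MCMC, which means it is called millions of times. As far as I can tell, most of its time is spent exponentiating an array, and I can't think of a way to improve its performance. It currently accounts for about 15% of the MCMC's total running time, so every bit of improvement matters.

Can this function be made faster?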

import numpy as np
import time as t

def f1(abc, emax, rnd, arr):
    """This function should be as fast as possible"""

    # Only the first 3 sub-arrays are modified
    final_arr = []
    for i, (a, b, c) in enumerate(abc):

        # This is an 'error' obtained through this function
        # It always uses arr[0], but the a,b,c values change
        sigma = a * np.exp(b * arr[0]) + c

        # Clip values at 'emax'
        sigma[sigma > emax[i]] = emax[i]

        # Add the errors to arr[i], 'rnd' is a random array
        # of floats normally distributed with mean 0 and
        # stdev 1.
        final_arr.append(arr[i] + rnd[:len(arr[i])] * sigma)

    return final_arr

# Some initial data with the proper shape. This data does
# not change with iterations
emax = [.05, .1, .17]
abc = [
    [0.01068788, 0.13260967, -0.03015476],
    [0.01068788, 0.13260967, -0.03015476],
    [0.01068788, 0.13260967, -0.03015476]]
rnd = np.random.normal(0., 1., 1000000)

t1, t2 = 0., 0.
for _ in range(10000):

    # Array of data with proper shape.
    # This array changes with successive iterations.
    arr1 = np.random.uniform(10., 30., (1, 1000))
    arr2 = np.random.uniform(0., .3, (9, 1000))
    arr = np.concatenate((arr1, arr2))

    s = t.time()
    f1(abc, emax, rnd, arr)
    t1 += t.time() - s

print(t1)

Tags: python, performance, numpy

Solution


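You can use broadcasting to vectorize the operation and save about 30% of the computation time. Just make sure to create NumPy arrays from the lists of values first: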

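def f2(abc, emax, rnd, arr):
    # abc and emax must be NumPy arrays. Indexing with [:, k, None] yields
    # (n, 1) columns that broadcast against arr[0, :] into an (n, m) array.
    sigma = abc[:, 0, None] * np.exp(abc[:, 1, None] * arr[0, :]) + abc[:, 2, None]
    # Clip each row at its own emax value
    sigma = np.clip(sigma, a_min=None, a_max=emax[:, None])
    # Only the first len(sigma) sub-arrays receive errors
    final_arr = arr[:len(sigma), :]
    return final_arr + rnd[:final_arr.shape[1]] * sigma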

emax = np.asarray(emax)
abc = np.asarray(abc)

np.allclose(f1(abc, emax, rnd, arr), f2(abc, emax, rnd, arr))
# True

%timeit f1(abc, emax, rnd, arr)
78.3 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit f2(abc, emax, rnd, arr)
54.8 µs ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
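
To make the broadcasting step easier to follow, here is a minimal sketch of the shapes involved. The small input values are hypothetical and chosen only so the result is easy to check by hand (b is set to 0 so that the exponential is 1). It uses np.minimum for the upper clip, which should be equivalent to np.clip with no lower bound; that substitution is mine, not part of the answer above.

```python
import numpy as np

# Hypothetical toy inputs (not the data from the question)
abc = np.array([[1.0, 0.0, 0.5],
                [2.0, 0.0, 0.5],
                [3.0, 0.0, 0.5]])        # one (a, b, c) row per sub-array
emax = np.array([1.0, 2.0, 10.0])        # one clipping limit per row
x = np.array([10.0, 20.0, 30.0, 40.0])   # stands in for arr[0]

# Adding a trailing axis turns each column of abc into shape (3, 1)
a = abc[:, 0, None]
b = abc[:, 1, None]
c = abc[:, 2, None]

# (3, 1) combined with (4,) broadcasts to (3, 4):
# every row uses its own a, b, c against the same x
sigma = a * np.exp(b * x) + c

# Row-wise upper clip: emax[:, None] has shape (3, 1)
clipped = np.minimum(sigma, emax[:, None])

print(a.shape, sigma.shape)
print(clipped)
```

All three rows are computed in a single call to np.exp instead of one call per loop iteration, which is where the vectorized version saves its Python-level overhead.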
