首页 > 解决方案 > Numba/numPy 多次运行速度差异/优化

问题描述

我看到了使用 Numba 的一些特殊性能,并且还希望进一步优化 JIT 循环。

初始化并生成一些实际相关的数据:

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from time import time
import numba

times =             np.arange(datetime(2000, 1, 1), datetime(2020, 2, 1), timedelta(minutes=10)).astype(np.datetime64)
tlen =              len(times)
A, Z =              np.array(['A', 'Z']).view('int32')
symbol_names =      np.random.randint(low=A, high=Z, size=1 * 7, dtype='int32').view(f'U{7}')
times =             np.concatenate([times] * 1)
names =             np.array([y for x in [[s] * tlen for s in symbol_names] for y in x])
open_column =       np.random.randint(low=40, high=60, size=len(times), dtype='uint32')
high_column =       np.random.randint(low=50, high=70, size=len(times), dtype='uint32')
low_column =        np.random.randint(low=30, high=50, size=len(times), dtype='uint32')
close_column =      np.random.randint(low=40, high=60, size=len(times), dtype='uint32')
df = pd.DataFrame({'open': open_column, 'high': high_column, 'low': low_column, 'close': close_column}, index=[names, times])
df.index = df.index.set_names(['Symbol', 'Date'])
df['entry'] = np.select( [df.open > df.open.shift(), False], (df.close, -1), np.nan)
df['exit'] =  df.close.where(df.high > df.open*1.33, np.nan)

定时功能:

def timing(f):
    def wrap(*args):
        time1 = time()
        ret = f(*args)
        time2 = time()
        print('{:s} function took {:.3f} s'.format(f.__name__, (time2-time1)))
        return ret
    return wrap

JIT编译函数:

@numba.jit(nopython=True)
def entry_exit(arr, limit=0, stop=0, tbe=0):
    is_active = 0
    bars_held = 0
    limit_target = np.inf
    stop_target = -np.inf
    result = np.empty(arr.shape[0], dtype='float32')

    for n in range(arr.shape[0]):
        ret = 0
        if is_active == 1:
            bars_held += 1
            if arr[n][2] < stop_target:
                ret = stop_target
                is_active = 0
            elif arr[n][1] > limit_target:
                ret = limit_target
                is_active = 0
            elif bars_held >= tbe:
                ret = arr[n][3]
                is_active = 0
            elif arr[n][5] > 0:
                ret = arr[n][3]
                is_active = 0
        if is_active == 0:
            if arr[n][4] > 0:
                is_active = 1
                bars_held = 0
                if stop != 0:
                    stop_target = arr[n][3] * stop
                if limit != 0:
                    limit_target = arr[n][3] * limit
        result[n] = ret
    return result

测试:

@timing
def run_one(arr):
    entry_exit(arr, limit=1.20, stop=0.50, tbe=5)

@timing
def run_ten(arr):
    for _ in range(10):
        entry_exit(arr, limit=1.20, stop=0.50, tbe=5)

arr = df[['open', 'high', 'low', 'close', 'entry', 'exit']].to_numpy()
run_one(arr)
run_ten(arr)

在本机 Python 中运行它时,我得到:

说得通。

当我在 JIT 中运行相同的程序时,我得到了完全不同的东西:

为什么会这样?我也很想知道如何进一步加速该功能,因为当前的速度增益虽然显着,但还不够。

标签: pythonnumpyjitnumba

解决方案


numba.jit首次使用时将编译该函数。这使得函数的第一次执行变得昂贵,而随后的执行则便宜得多。

您的测试可能会运行run_one- 它调用entry_exitwhich numba 编译 - 因此编译速度很慢,但运行速度很快。然后它调用run_ten,但entry_exit已经编译,所以编译的形式被重用 - 所以它很快。

总之,我希望故障类似于

run_one: 0.74s compile + 1 x 0.01s run
run_ten: no compile + 10 x 0.01s run

要检查这一点,您只需确保在开始测试其速度之前调用该函数一次(以便 numba 编译它)。或者您可以设置标志来告诉 numba 提前编译。

验证这一点所需要做的就是将测试脚本更改为:

@timing
def run_one(arr):
    entry_exit(arr, limit=1.20, stop=0.50, tbe=5)

@timing
def run_ten(arr):
    for _ in range(10):
        entry_exit(arr, limit=1.20, stop=0.50, tbe=5)

arr = df[['open', 'high', 'low', 'close', 'entry', 'exit']].to_numpy()

# Run it once so that numba compiles it
entry_exit(arr, limit=1.20, stop=0.50, tbe=5)

# Use the compiled version
run_one(arr)
run_ten(arr)

推荐阅读