python - 使用 numba 加速 dtaidistance 键功能
问题描述
该DTAIDistance
包可用于查找k
输入查询的最佳匹配。但不能用于多维输入查询。此外,我想k
在一次运行中找到许多输入查询的最佳匹配。
我修改了这个DTAIDistance
函数,使它可以用于搜索多维多查询的子序列。我使用njit
with parallel 来加快处理速度,即 p_calc 函数将 numba-parallel 应用于每个输入查询。但我发现与只是简单地逐一循环输入查询(即 calc 函数)相比,并行计算似乎并没有加快计算速度。
import time
from tqdm import tqdm
from numba import njit, prange
import numpy as np
inf = np.inf
argmin=np.argmin
@njit(fastmath=True, nogil=True, error_model="numpy", cache=True, parallel=False)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
n_series = s1.shape[1]
ndim = s1.shape[2]
# s1 = np.ascontiguousarray(s1)#.shape
# s2 = np.ascontiguousarray(s2)#.shape
# dtw = np.full((n_series,r + 1, c + 1), np.inf,dtype=s1.dtype) # cmath.inf
# d = np.full((n_series), np.inf,dtype=s1.dtype) # cmath.inf
for i in range(psi_2b + 1):
dtw[:, 0, i] = 0
for i in range(psi_1b + 1):
dtw[:, i, 0] = 0
for nn in prange(n_series):
print('im alive...')
i0 = 1
i1 = 0
sc = 0
ec = 0
smaller_found = False
ec_next = 0
for i in range(r):
i0 = i
i1 = i + 1
j_start = max(0, i - max(0, r - c) - window + 1)
j_end = min(c, i + max(0, c - r) + window)
if sc > j_start:
j_start = sc
smaller_found = False
ec_next = i
for j in range(j_start, j_end):
val = 0
tmp = ((s1[i, nn] - s2[j]) ** 2)
# tmp = (np.abs(s1[i, nn] - s2[j, 0]))
for nd in range(ndim):
val += tmp[nd]
d[nn] = val
# d = np.sum(np.abs(s1[i] - s2[j]) ) # multi-d
if max_step is not None and d[nn] > max_step:
continue
# print(i, j + 1 - skip, j - skipp, j + 1 - skipp, j - skip)
dtw[nn, i1, j + 1] = d[nn] + min(dtw[nn, i0, j],
dtw[nn, i0, j + 1] + penalty,
dtw[nn, i1, j] + penalty)
# dtw[i + 1, j + 1 - skip] = d + min(dtw[i + 1, j + 1 - skip], dtw[i + 1, j - skip])
if dtw[nn, i1, j + 1] > max_dist:
if not smaller_found:
sc = j + 1
if j >= ec:
break
else:
smaller_found = True
ec_next = j + 1
ec = ec_next
# Decide which d to return
dtw[nn] = np.sqrt(dtw[nn])
if psi_1e == 0 and psi_2e == 0:
d[nn] = dtw[nn, i1, min(c, c + window - 1)]
else:
ir = i1
ic = min(c, c + window - 1)
if psi_1e != 0:
vr = dtw[nn, ir:max(0, ir - psi_1e - 1):-1, ic]
mir = np.argmin(vr)
vr_mir = vr[mir]
else:
mir = ir
vr_mir = inf
if psi_2e != 0:
vc = dtw[nn, ir, ic:max(0, ic - psi_2e - 1):-1]
mic = np.argmin(vc)
vc_mic = vc[mic]
else:
mic = ic
vc_mic = inf
if vr_mir < vc_mic:
if psi_neg:
dtw[nn, ir:ir - mir:-1, ic] = -1
d[nn] = vr_mir
else:
if psi_neg:
dtw[nn, ir, ic:ic - mic:-1] = -1
d[nn] = vc_mic
if max_dist and d[nn] ** 2 > max_dist:
# if max_dist and d[nn] > max_dist:
d[nn] = inf
return d, dtw
@njit(fastmath=True, nogil=True) # Set "nopython" mode for best performance, equivalent to @njit
def calc(s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
dtw = np.full((r + 1, c + 1), np.inf) # cmath.inf
for i in range(psi_2b + 1):
dtw[0, i] = 0
for i in range(psi_1b + 1):
dtw[i, 0] = 0
i0 = 1
i1 = 0
sc = 0
ec = 0
smaller_found = False
ec_next = 0
for i in range(r):
i0 = i
i1 = i + 1
j_start = max(0, i - max(0, r - c) - window + 1)
j_end = min(c, i + max(0, c - r) + window)
if sc > j_start:
j_start = sc
smaller_found = False
ec_next = i
for j in range(j_start, j_end):
# d = (s1[i] - s2[j]) ** 2# 1-d
d = np.sum((s1[i] - s2[j]) ** 2) # multi-d
# d = np.sum(np.abs(s1[i] - s2[j]) ) # multi-d
if max_step is not None and d > max_step:
continue
dtw[i1, j + 1] = d + min(dtw[i0, j],
dtw[i0, j + 1] + penalty,
dtw[i1, j] + penalty)
if dtw[i1, j + 1] > max_dist:
if not smaller_found:
sc = j + 1
if j >= ec:
break
else:
smaller_found = True
ec_next = j + 1
ec = ec_next
# Decide which d to return
dtw = np.sqrt(dtw)
if psi_1e == 0 and psi_2e == 0:
d = dtw[i1, min(c, c + window - 1)]
else:
ir = i1
ic = min(c, c + window - 1)
if psi_1e != 0:
vr = dtw[ir:max(0, ir - psi_1e - 1):-1, ic]
mir = argmin(vr)
vr_mir = vr[mir]
else:
mir = ir
vr_mir = inf
if psi_2e != 0:
vc = dtw[ir, ic:max(0, ic - psi_2e - 1):-1]
mic = argmin(vc)
vc_mic = vc[mic]
else:
mic = ic
vc_mic = inf
if vr_mir < vc_mic:
if psi_neg:
dtw[ir:ir - mir:-1, ic] = -1
d = vr_mir
else:
if psi_neg:
dtw[ir, ic:ic - mic:-1] = -1
d = vc_mic
if max_dist and d * d > max_dist:
d = inf
return d, dtw
mydtype = np.float32
series1 = np.random.random((16, 30, 2)).astype(mydtype)
series2 = np.random.random((100000, 2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype) # cmath.inf
d = np.full((n_series), np.inf, dtype=mydtype) # cmath.inf
time1 = time.time()
d, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0,
series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)
time1 = time.time()
for ii in tqdm(range(series1.shape[1])):
d, dtw1 = calc( series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)# this one is faster
如何加速 calc 函数或 p_calc 函数,以便计算多维多查询的动态时间规整路径?
感谢您的回答,然后我修改了代码以进行简化。我删除了 np.sum 部分并使用循环,我可以获得另一个加速。对进一步加速有什么建议吗?
import time
from numba import njit, prange
import numpy as np
inf = np.inf
argmin=np.argmin
@njit(fastmath=True, nogil=True, error_model="numpy", cache=False, parallel=True)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
n_series = s1.shape[1]
ndim = s1.shape[2]
for nn in prange(n_series):
for i in range(r):
j_start = 0
j_end = c
for j in range(j_start, j_end):
val = 0
# tmp = ((s1[i, nn] - s2[j]) ** 2)
# tmp = (np.abs(s1[i, nn] - s2[j, 0]))
for nd in range(ndim):
tmp = ((s1[i, nn,nd] - s2[j,nd]) ** 2)
val += tmp
d[nn] = val
return d, dtw
@njit(fastmath=True, nogil=True) # Set "nopython" mode for best performance, equivalent to @njit
def calc(dtw,s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
ndim = s1.shape[-1]
for i in range(r):
j_start = 0
j_end = c
for j in range(j_start, j_end):
d = 0
for kk in range(ndim):
d += (s1[i, kk] - s2[j, kk]) ** 2
return d, dtw
mydtype = np.float32
series1 = np.random.random((16, 300, 2)).astype(mydtype)
series2 = np.random.random((1000000, 2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype) # cmath.inf
d = np.full((n_series), np.inf, dtype=mydtype) # cmath.inf
time1 = time.time()
# assert 1==2
# dtw[:,series2.shape[0]]
d1, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0, series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)
# assert 1==2
time1 = time.time()
dtw = np.full(( r + 1, c + 1), np.inf, dtype=mydtype) # cmath.inf
for ii in (range(series1.shape[1])):
d2, dtw2 = calc( dtw,series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.time() - time1)# this one is faster
np.allclose(dtw1[-1],dtw2)
np.allclose(d1[-1],d2)
编辑:
我发现以下代码的性能如果使用pass
或break
. 我不明白为什么?
@njit(fastmath=True, nogil=True)
def kbest_matches(matching,k=4000):
ki = 0
while ki < k:
best_idx =np.argmin(matching)# np.argmin(np.arange(10000000))#
if best_idx == 0 :
# pass
break
ki += 1
return 0
ss= np.random.random((1575822,))
time1 = time.time()
pp = kbest_matches(ss)
print(time.time() - time1)
解决方案
我假设这两种实现的代码都是正确的并且经过仔细检查(否则基准将毫无意义)。
问题可能来自函数的编译时间。事实上,第一次调用比下一次调用慢得多,即使使用cache=True
. 这对于并行实现尤其重要,因为编译并行 Numba 代码通常较慢(因为它更复杂)。避免这种情况的最佳解决方案是通过向 Numba 提供类型来提前编译 Numba 函数。
除此之外,仅对计算进行一次基准测试通常被认为是一种不好的做法。好的基准执行多次迭代并删除第一个迭代(或单独考虑它们)。实际上,第一次执行代码时可能会出现其他几个问题:CPU缓存(和 TLB)很冷,CPU频率在执行过程中会发生变化,并且在程序刚启动时可能会变小,可能需要页面错误需要等
在实践中,我无法重现该问题。实际上,p_calc
在我的 6 核机器上快 3.3 倍。当基准测试在 5 次迭代的循环中完成时,并行实现的测量时间要小得多:大约 13 次(这对于在 6 核机器上使用 6 个线程的并行实现实际上是可疑的)。
推荐阅读
- java - 如何在登录主页后重定向用户并使用 Spring Security 抛出 200 而不是 302?
- python - 将 str 列表映射到 int 列表
- python - 我可以将文件位置缩短为 .py 文件位置吗
- c# - 在标签上显示所选 ListBoxItem 的值
- javascript - 如何从 mysql 返回 JSON 对象?
- react-native - React Native 的 createStackNavigator()
- python - 正则表达式:匹配用逗号分隔的结构的另一种方法
- python - 无法使用 Homebrew 在 macOS 中降级 python 版本
- regex - 我如何正则表达式匹配这 3 个场景?
- java - How to add HTTP Response Header by a value of Response Payload in Spring Boot