首页 > 解决方案 > 如果数组列表的长度非常大,如何以更快的方法检查输入数组是否与数组列表最相似?

问题描述

我有一个包含许多列表的列表。我使用 dtw 方法使用 fastdtw python 包来计算输入列表和列表内的列表之间的距离。这给了我一个距离列表,我从中选择最小值并将其估计为最接近输入数组的距离。此过程有效,但如果列表的数量和长度很大,则它会占用大量 CPU 资源且耗时。

from fastdtw import fastdtw
import scipy.spatial.distance as ssd

inputlist = [1,2,3,4,5]
complelte_list = [[1,1,3,9,1],[1,2,6,4],[9,8,7,4,2]]
dst = []
for arr in complete_lists:
   distance, path = fastdtw(arr,inputlist,dist=ssd.euclidean)
   dst.append(distance)

标签: pythonpython-3.xdtw

解决方案


如果您需要最近的距离,而不是所有距离,请构建一棵树,例如

from sklearn.neighbors import BallTree
import numpy as np

inputlist = [1,2,3,4,5]
complelte_list = [[1,1,3,9,1],[1,2,6,4,5],[9,8,7,4,2]]

tree = BallTree(np.array(complelte_list), leaf_size=10, metric='euclidean')

并查询

distance, index = tree.query(np.expand_dims(np.array(inputlist),axis=0), k=1, return_distance=True)

哪个返回distance最接近k=1的,也返回index,例如

print('Most similar to inputlist is')
print( complelte_list[ index[0][0] ] )

如果速度很重要,您可以调整leaf_size=10并尝试适合您的尺寸。构建树也需要时间,因此如果这对您的情况有意义,请确保这也是您的基准测试的一部分。


推荐阅读