首页 > 解决方案 > 比较重叠的数组列表

问题描述

我有一个 [a, [b, c, d]] 形式的数组列表列表,类似于

dataset = [[1, array[1, 2, 3, 4]], [0.5, array[2, 2, 3, 5]], [1.5, array[4, 3, 2, 1]], ...]

我想比较每个数组并确定它们之间的重叠量。在上面的示例中,这意味着识别 [[1, array[1, 2, 3, 4]], [0.5, array[2, 2, 3, 5], ...]]。我关心重叠的价值观和立场。

如果超过某个阈值(例如 1/3)重叠,我想从数据集中消除系数较低的值。在上面的示例中,这将是第二个数组,其值为 0.5 而不是 1。

对于上面的列表,输出将是:

 [[1, array[1, 2, 3, 4]], [1.5, array[4, 3, 2, 1]], ...]

我已经设法整合了一个解决方案(如下),但速度很慢。我确信有更好的方法来解决这个问题,我只是不确定它是什么。

survivors = dataset
for i, pair in enumerate(dataset):
        keep_arr = [veto_duplicate(pair, dup) for dup in survivors]
        survivors = list(compress(survivors,keep_arr))
return survivors

def veto_duplicate(path1, path2):
        fractional_overlap = sum(path1[1] == path2[1])/len(path1[1])
        if fractional_overlap > 0.25 and fractional_overlap < 1:
                        if path1[0] < path2[0]:
                                return False
                        else:
                                return True
        else:
                return True

如果有人可以提出一种更快的方法来做到这一点,我将不胜感激。

编辑

所有内部数组的大小相同。最终结果不应包含重叠的数组。如果有三个(或更多)数组重叠,我只想保留系数最高的那个。

标签: pythonarraysnumpy

解决方案


这是一个更快的解决方案。希望能帮助到你:

dataset = [[1, np.array([1, 2, 3, 4])], [0.5, np.array([2, 2, 3, 5])], [1.5, np.array([4, 3, 2, 1])]]
#sort dataset by first element
dataset.sort(key=lambda x:x[0])
#unnormalize threshold to reduce division everytime
threshold = 1/3 * len(dataset[0][1])
survivors = dataset.copy()
id_to_del = []
for i, pair in enumerate(dataset):
    for dup in dataset[i+1:]:
        if sum(pair[1]==dup[1])>threshold:
            id_to_del.append(i)
            break

for i in id_to_del:
    del(survivors[i])
#converting to numpy and deleting might be faster if list is too long, but requires reformatting:
#survivors = np.delete(np.array(survivors), id_to_del, axis=0)

输出:

[[1, array([1, 2, 3, 4])], [1.5, array([4, 3, 2, 1])], ...]

推荐阅读