首页 > 解决方案 > 仅删除该行 3D numpy 数组中包含重复项的行

问题描述

我有一个像这样的 3D numpy 数组:

>>> a
array([[[0, 1, 2],
        [0, 1, 2],
        [6, 7, 8]],
       [[6, 7, 8],
        [0, 1, 2],
        [6, 7, 8]],
       [[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]])

我只想删除那些本身包含重复项的行。例如,输出应如下所示:

>>> remove_row_duplicates(a)
array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]])

这是我正在使用的功能:

delindices = np.empty(0, dtype=int)

for i in range(len(a)):
    _, indices = np.unique(np.around(a[i], decimals=10), axis=0, return_index=True)

    if len(indices) < len(a[i]):

        delindices = np.append(delindices, i) 

a = np.delete(a, delindices, 0)

这很完美,但问题是我的数组形状就像(1000000,7,3)。for 循环在 python 中非常慢,这需要很多时间。我的原始数组也包含浮点数。任何有更好的解决方案或可以帮助我矢量化此功能的人?

标签: pythonnumpymultidimensional-array

解决方案


沿着每个2D block ie的行对其进行排序axis=1,然后沿着连续的行查找匹配的行,最后any沿着相同的行查找匹配axis=1-

b = np.sort(a,axis=1)
out = a[~((b[:,1:] == b[:,:-1]).all(-1)).any(1)]

示例运行说明

输入数组:

In [51]: a
Out[51]: 
array([[[0, 1, 2],
        [0, 1, 2],
        [6, 7, 8]],

       [[6, 7, 8],
        [0, 1, 2],
        [6, 7, 8]],

       [[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]])

代码步骤:

# Sort along axis=1, i.e rows in each 2D block
In [52]: b = np.sort(a,axis=1)

In [53]: b
Out[53]: 
array([[[0, 1, 2],
        [0, 1, 2],
        [6, 7, 8]],

       [[0, 1, 2],
        [6, 7, 8],
        [6, 7, 8]],

       [[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]])

In [54]: (b[:,1:] == b[:,:-1]).all(-1) # Look for successive matching rows
Out[54]: 
array([[ True, False],
       [False,  True],
       [False, False]])

# Look for matches along each row, which indicates presence
# of duplicate rows within each 2D block in original 2D array
In [55]: ((b[:,1:] == b[:,:-1]).all(-1)).any(1)
Out[55]: array([ True,  True, False])

# Invert those as we need to remove those cases
# Finally index with boolean indexing and get the output
In [57]: a[~((b[:,1:] == b[:,:-1]).all(-1)).any(1)]
Out[57]: 
array([[[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]]])

推荐阅读