首页 > 解决方案 > 消除不满足两个条件的数组行

问题描述

考虑数组 m 和 n。两者具有相同的形状。m 和 n 总是有偶数列,我添加了空格来强调每个数组行由PAIRS元素组成。

import numpy as np

m = np.array([[5,3,  6,7,  3,8],
              [5,4,  5,1,  4,5],
              [5,4,  2,4,  4,6],
              [2,2,  2,3,  8,5],
              [2,7,  8,7,  1,2],
              [2,7,  8,7,  3,2],
              [2,7,  8,7,  4,2]]) 

n = np.array([[1,3,  6,7,  1,12],
              [2,4,  2,9,  4,9],
              [1,5,  5,12, 9,1],
              [5,4,  5,6,  9,5],
              [5,4,  1,4,  1,5],
              [5,4,  1,7,  9,5],
              [5,4,  1,5  4,11]]) 

我的总体目标是从 m 中消除不符合两个 TEST 条件的行:

测试 1:只有当每一对与行中的其他对具有共同元素时, 我们才保留 m行。 Quang Hoang 在 2020 年 11 月 13 日已经非常有能力地回答了这个问题(https://stackoverflow.com/a/64814379/11188140)。我将其包含在此处是因为该代码将有助于后续的 TEST 2。

使用此代码,m 的第一行被拒绝,因为对 (6,7) 在其他行对中没有共同元素。m 的第 4 行也被拒绝,因为它的最后一对在其他行对中没有共同元素。较早的代码可以完美运行,其输出如下:

a = m.reshape(m.shape[0],-1,2)
mask = ~np.eye(a.shape[1], dtype=bool)[...,None]

is_valid = (((a[...,None,:]==a[:,None,...])&mask).any(axis=(-1,-2))
            |((a[...,None,:]==a[:,None,:,::-1])&mask).any(axis=(-1,-2))
           ).all(-1)

m[is_valid]

输出 m 为:

              [[5,4,  5,1,  4,5],
               [5,4,  2,4,  4,6],
               [2,7,  8,7,  1,2],
               [2,7,  8,7,  3,2],
               [2,7,  8,7,  4,2]]

测试 2:仅当数组 n 中的行(即:相同的索引)具有与 m 中相同的配对匹配时, 我们才保留 m的行。 n 行也可能有“额外的”对匹配,但它 必须包括 m 具有的对匹配。

三个例子说明了这一点:

a) 5th row of m and n:  '[2,7,  8,7,  1,2]' and '[5,4,  1,4,  1,5]'
   In m, the 1st & 2nd pairs share an element, and the 1st and & 3rd pairs share an element.
   n has both of these matchings, so **TEST 2 PASSES**, and we keep this m row. 
   The fact that 2nd & 3rd pairs of n also share an element is immaterial.                                                                          
                                                                                                
b) 6th row of m and n:  '[2,7,  8,7,  3,2]' and '[5,4,  1,7,  9,5]'
   In m, the 1st & 2nd pairs share an element, and the 1st & 3rd pairs share an element.
   n DOES NOT have BOTH of these matchings, so **TEST 2 FAILS**.  The row may be eliminated from m (and from n too if that's needed)

c) 3rd row of m and n:  '[5,4,  2,4,  4,6]' and '[1,5,  5,12, 9,1]'
   In m, the 1st & 2nd pairs share an element, the 2nd & 3rd pairs share an element, and the 1st & 3rd pairs share an element.
   n lacks the 2nd & 3rd pair sharing that m has, so **TEST 2 FAILS**.  The row may be eliminated from m (and from n too if that's needed)

通过两项测试后,最终输出 m 为:

              [[5,4,  5,1,  4,5],
               [2,7,  8,7,  1,2],
               [2,7,  8,7,  4,2]]

标签: pythonarraysnumpy

解决方案


这是您提到的上述 2 个测试的矢量化解决方案。我也对 test1 的代码进行了一些修改,以便它很好地流入 test2。

import numpy as np

m = np.array([[5,3,  6,7,  3,7],
              [5,4,  5,1,  4,5],
              [5,4,  2,4,  4,6],
              [2,2,  2,3,  8,5],
              [2,7,  8,7,  1,2],
              [2,7,  8,7,  3,2],
              [2,7,  8,7,  4,2]]) 

n = np.array([[1,3,  6,7,  1,12],
              [2,4,  2,9,  4,9],
              [1,5,  5,12, 9,1],
              [5,4,  5,6,  9,5],
              [5,4,  1,4,  1,5],
              [5,4,  1,7,  9,5],
              [5,4,  1,5,  4,11]])


mm = m.reshape(m.shape[0],-1,2) #(7,3,2)
nn = n.reshape(m.shape[0],-1,2) #(7,3,2)

#broadcast((7,3,1,2,1), (7,1,3,1,2)) -> (7,3,3,2,2) -> (7,3,3)
matches_m = np.any(mm[:,:,None,:,None] == mm[:,None,:,None,:], axis=(-1,-2))  #(7,3,3)
matches_n = np.any(nn[:,:,None,:,None] == nn[:,None,:,None,:], axis=(-1,-2))  #(7,3,3)

mask = ~np.eye(mm.shape[1], dtype=bool)[None,:]  #(1,3,3)

is_valid1 = np.all(np.any(matches_m&mask, axis=-1), axis=-1)  #(7,)
is_valid2 = np.all(~(matches_m^matches_n)|matches_n, axis=(-1,-2)) #(7,)

m[is_valid1 & is_valid2]
array([[5, 4, 5, 1, 4, 5],
       [2, 7, 8, 7, 1, 2],
       [2, 7, 8, 7, 4, 2]])

解释 -

测试 1

  1. 第一步是重塑 n 和 m(7,3,2)以便我们可以广播axis=1
  2. 接下来,我们需要比较最后一个轴的元素,以获得每行元素的叉积。所以预期的输出是(7 rows, 3X3 cross product between elements).
  3. 但是为了比较,我还必须在最后一个轴(2x2)上广播。这意味着我需要一个(7,3,3,2,2). 这可以通过广播 2 个(7,3,1,2,1)和数组来完成(7,1,3,1,2)
  4. 最后,最后 2 个轴上的 any() 将为您提供一个(7,3,3)位置,对于 7 行中的每一行,您将每个元素与另一个元素进行比较,如果其中任何一个具有公共元素,则返回 True。THIS IS ALSO THE ARRAY THAT CONTAINS THE MATCHES AND WILL BE IMPORTANT FOR TEST2!
  5. 接下来,由于相同的元素比较总是会给出 True,你想忽略对角线,所以为它创建一个掩码。
  6. 使用掩码,获取哪些元素与除自身之外的其他元素至少有 1 次匹配,这解决了 TEST1。

测试 2

  1. 应用相同的前 4 个步骤来获取 n 的匹配矩阵(7,3,3)
  2. 这里,any match that exists in m MUST exist in n,但是any match that exists in n is not needed in m。所需的逻辑是 -
a = np.array([True, False, True, False])
b = np.array([True, True, False, False])

~(a^b)|b
#array([ True,  True, False,  True])
  1. 将此应用于 2 个(7,3,3)匹配项,如果在各自的 中甚至存在单个 false (3,3),则表明 m 中的对之间的匹配没有反映在 n 中。所以你到达False那里。这过度axis = -1, -2导致(7,)

  2. 这为您提供了 TEST2。

希望这能解决您的问题。


推荐阅读