首页 > 解决方案 > Numpy:检索过滤数组中原始列的最有效方法

问题描述

假设我们有一个arr1形状为 (10,4) 的二维数组:

import numpy as np
np.set_printoptions(suppress=True)

np.random.seed(0)
a = np.random.rand(10, 3)
np.random.seed(0)
b = np.random.randint(24, size=10)
b = b.reshape(len(b),1)
arr1 = np.hstack((a, b))
print(arr1)

arr1看起来像这样:

[[ 0.5488135   0.71518937  0.60276338 12.        ]
[ 0.54488318  0.4236548   0.64589411 15.        ]
[ 0.43758721  0.891773    0.96366276 21.        ]
[ 0.38344152  0.79172504  0.52889492  0.        ]
[ 0.56804456  0.92559664  0.07103606  3.        ]
[ 0.0871293   0.0202184   0.83261985  3.        ]
[ 0.77815675  0.87001215  0.97861834  7.        ]
[ 0.79915856  0.46147936  0.78052918  9.        ]
[ 0.11827443  0.63992102  0.14335329 19.        ]
[ 0.94466892  0.52184832  0.41466194 21.        ]]

现在,执行了一些外部过程,因此我们丢失了最后一列中的信息,并且过滤了一些行:

arr2 = arr1[:,:3]
np.random.seed(0)
filter_arr = np.random.choice(10, size=6, replace=False)
arr2 = arr2[filter_arr]
print(arr2)

结果,我们得到了以下数组arr2

[[0.43758721 0.891773   0.96366276]
[0.11827443 0.63992102 0.14335329]
[0.56804456 0.92559664 0.07103606]
[0.94466892 0.52184832 0.41466194]
[0.54488318 0.4236548  0.64589411]
[0.77815675 0.87001215 0.97861834]]

目标是根据前三列中的值有效地检查哪些行保留在中arr2,并将第四列中的值arr1相加arr2。当然,filter_arr上一步中的内容将是完全未知的。

预期的结果是这样的:

[[0.43758721 0.891773  0.96366276 21.        ]
[0.11827443 0.63992102 0.14335329 19.        ]
[0.56804456 0.92559664 0.07103606  3.        ]
[0.94466892 0.52184832 0.41466194 21.        ]
[0.54488318 0.4236548  0.64589411 15.        ]
[0.77815675 0.87001215 0.97861834  7.        ]]

谢谢。

PS如果您为这个问题想出一个更好的标题,请告诉我更改它以便对其他用户更有用。

标签: pythonnumpy

解决方案


如果您可以将广播放入内存中,则可以比较 with 的前 3 列,arr1然后arr2numpy.all之后(沿右轴)使用 numpy.argmax 来检索 from 每一行的索引arr2最后arr1使用这些索引来获取的最后一列,arr1并与arr2.

正如 Jérôme Richard 指出的那样,使用==比较不安全,您可以使用numpy.isclose函数并根据需要自定义公差。

关于具有np.nan值的可能性,np.isclose还接受equal_nan默认情况下为 False 但可以设置为 True 以便np.isclose(np.nan, np.nan, equal_nan=True)返回的参数True(如果这是您想要的行为)。

import numpy as np

idx = np.argmax(np.all(np.isclose(arr1[:, :3, None],arr2.T[None, :, :]), axis=1), axis=0)

filtered_last_col = arr1[idx, -1, None]

np.hstack([arr2, filtered_last_col])
array([[ 0.43758721,  0.891773  ,  0.96366276, 21.        ],
       [ 0.11827443,  0.63992102,  0.14335329, 19.        ],
       [ 0.56804456,  0.92559664,  0.07103606,  3.        ],
       [ 0.94466892,  0.52184832,  0.41466194, 21.        ],
       [ 0.54488318,  0.4236548 ,  0.64589411, 15.        ],
       [ 0.77815675,  0.87001215,  0.97861834,  7.        ]])

推荐阅读