python - Numpy:检索过滤数组中原始列的最有效方法
问题描述
假设我们有一个arr1
形状为 (10,4) 的二维数组:
import numpy as np
np.set_printoptions(suppress=True)
np.random.seed(0)
a = np.random.rand(10, 3)
np.random.seed(0)
b = np.random.randint(24, size=10)
b = b.reshape(len(b),1)
arr1 = np.hstack((a, b))
print(arr1)
arr1
看起来像这样:
[[ 0.5488135 0.71518937 0.60276338 12. ]
[ 0.54488318 0.4236548 0.64589411 15. ]
[ 0.43758721 0.891773 0.96366276 21. ]
[ 0.38344152 0.79172504 0.52889492 0. ]
[ 0.56804456 0.92559664 0.07103606 3. ]
[ 0.0871293 0.0202184 0.83261985 3. ]
[ 0.77815675 0.87001215 0.97861834 7. ]
[ 0.79915856 0.46147936 0.78052918 9. ]
[ 0.11827443 0.63992102 0.14335329 19. ]
[ 0.94466892 0.52184832 0.41466194 21. ]]
现在,执行了一些外部过程,因此我们丢失了最后一列中的信息,并且过滤了一些行:
arr2 = arr1[:,:3]
np.random.seed(0)
filter_arr = np.random.choice(10, size=6, replace=False)
arr2 = arr2[filter_arr]
print(arr2)
结果,我们得到了以下数组arr2
:
[[0.43758721 0.891773 0.96366276]
[0.11827443 0.63992102 0.14335329]
[0.56804456 0.92559664 0.07103606]
[0.94466892 0.52184832 0.41466194]
[0.54488318 0.4236548 0.64589411]
[0.77815675 0.87001215 0.97861834]]
目标是根据前三列中的值有效地检查哪些行保留在中arr2
,并将第四列中的值arr1
相加arr2
。当然,filter_arr
上一步中的内容将是完全未知的。
预期的结果是这样的:
[[0.43758721 0.891773 0.96366276 21. ]
[0.11827443 0.63992102 0.14335329 19. ]
[0.56804456 0.92559664 0.07103606 3. ]
[0.94466892 0.52184832 0.41466194 21. ]
[0.54488318 0.4236548 0.64589411 15. ]
[0.77815675 0.87001215 0.97861834 7. ]]
谢谢。
PS如果您为这个问题想出一个更好的标题,请告诉我更改它以便对其他用户更有用。
解决方案
如果您可以将广播放入内存中,则可以比较 with 的前 3 列,arr1
然后arr2
在numpy.all之后(沿右轴)使用 numpy.argmax 来检索 from 每一行的索引,arr2
最后arr1
使用这些索引来获取的最后一列,arr1
并与arr2
.
正如 Jérôme Richard 指出的那样,使用==
比较不安全,您可以使用numpy.isclose函数并根据需要自定义公差。
关于具有np.nan
值的可能性,np.isclose
还接受equal_nan
默认情况下为 False 但可以设置为 True 以便np.isclose(np.nan, np.nan, equal_nan=True)
返回的参数True
(如果这是您想要的行为)。
import numpy as np
idx = np.argmax(np.all(np.isclose(arr1[:, :3, None],arr2.T[None, :, :]), axis=1), axis=0)
filtered_last_col = arr1[idx, -1, None]
np.hstack([arr2, filtered_last_col])
array([[ 0.43758721, 0.891773 , 0.96366276, 21. ],
[ 0.11827443, 0.63992102, 0.14335329, 19. ],
[ 0.56804456, 0.92559664, 0.07103606, 3. ],
[ 0.94466892, 0.52184832, 0.41466194, 21. ],
[ 0.54488318, 0.4236548 , 0.64589411, 15. ],
[ 0.77815675, 0.87001215, 0.97861834, 7. ]])
推荐阅读
- office-js - 清单中的语言覆盖不起作用
- angular - 当我在 Android 模拟器上运行构建时,Ionic App 在登录后崩溃
- macos - Mac vscode:使用 ctrl + 鼠标单击的多行
- javascript - 如何在firebase中从一个集合访问另一个集合
- jmeter - 当 jmeter 等待先前的响应发送新请求时,如何在 jmeter 中表示现实世界?
- python - 如何在 python 中模拟 Firestore where 函数?
- ios - iOS 设备 Flutter 上的 Flutter image_picker 问题
- php - Guzzle 在 16kB 后截断标题
- python - 使用带有深度学习的python预处理训练函数的ktrain这个错误的含义是什么
- microsoft-graph-api - MS Graph API 是否支持 Microsoft 365 Defender