首页 > 解决方案 > 具有多维索引列表的 Numpy 掩码

问题描述

如果我有数组

data = np.array([[0.75, 0.05, 0.1, 0.2],
                 [0.4, 0.3, 0.2, 0.1]])

labels = np.array([3,1])

如何将索引覆盖在数据上以便得到这个结果?

np.array([0.2, 0.3])

换句话说,数组的索引与labels数组匹配,data并且每个索引的值对应于从data数组中的行中获取的索引值。

我一直在尝试找到一种有效的方法来做到这一点,所以如果存在这种方法,最好是矢量化的并且更容易理解。

标签: pythonarraysnumpy

解决方案


除了@wjandrea 提到的索引方法之外,很容易看出您想要的项目将位于索引矩阵的对角线data[:,labels],或data.T[labels]

所以试试这个 -

data[:,labels].diagonal()

#OR

data.T[labels].diagonal()
array([0.2, 0.3])

您可以执行此操作的另一种方法是使用np.take_along_axiswhich 可让您沿轴选择特定索引。在这种情况下,该轴为 0。但您必须重塑标签以使每行有 1 个标签。

np.take_along_axis(data, labels[:,None], axis=1)
array([[0.2],
       [0.3]])

然后,您可以转置或使用 flatten 通过使用output.ravel()


推荐阅读