首页 > 解决方案 > Numpy: How to select row entries in a 2d array by column vector

问题描述

How can I retrieve a column vector from a 2d array given an indicator column vector?

Suppose I have

X = np.array([[1, 4, 6],
              [8, 2, 9],
              [0, 3, 7],
              [6, 5, 1]])

and

S = np.array([0, 2, 1, 2])

Is there an elegant way to get from X and S the result array([1, 9, 3, 1]), which is equivalent to

np.array([x[s] for x, s in zip(X, S)])

标签: numpyindexing

解决方案


您可以使用以下方法实现此目的np.take_along_axis

>>> np.take_along_axis(X, S[..., None], axis=1)
array([[1],
       [9],
       [3],
       [1]])

您需要确保两个数组参数的形状相同(或者可以应用广播),因此需要S[..., None]广播。

当然,您可以使用[:, 0]切片重塑返回值。

>>> np.take_along_axis(X, S[..., None], axis=1)[:, 0]
array([1, 9, 3, 1])

或者,您可以只使用带有排列的索引:

>>> X[np.arange(len(S)), S[np.arange(len(S))]]
array([1, 9, 3, 1])

我相信这也相当于np.diag(X[:, S])但有不必要的复制......


推荐阅读