首页 > 解决方案 > Numpy:使用 2d 数组沿 3d 数组的轴选择

问题描述

我已经为此苦苦挣扎了几个小时,无法完全理解它。设置是这样的:

    A.shape # (T,N,K)
    B.shape #   (L,K) L < N

2D B 数组的 K 列中的每一列都索引同一 K 行的 N 列之一。我可以通过以下方式轻松抓取任何特定的 k 切片

    A[:,B[:,k],k].shape # (T,L)

但是,循环 K 并不理想,因为 A 是一个非常大的矩阵

我敢肯定有人有一个非常简单的答案,但我很难过。

编辑:我还应该补充一点,我需要保留 A 矩阵的 3D 结构。我想出了如何获取单个值,但只能在 (TxLxK,) 数组中。

标签: pythonnumpy

解决方案


您可以使用np.take_along_axis

np.take_along_axis(A,B[None,...],axis=1)

例如,

A = np.linspace(1,24,24).reshape(3,4,2)
B = np.repeat([[0,1]],3,axis=0)

np.take_along_axis(A,B[None,...],axis=1)

结果是

array([[[ 1.,  4.],
        [ 1.,  4.],
        [ 1.,  4.]],

       [[ 9., 12.],
        [ 9., 12.],
        [ 9., 12.]],

       [[17., 20.],
        [17., 20.],
        [17., 20.]]])

推荐阅读