首页 > 解决方案 > 根据掩码从3d数组中提取切片的pythonic方法

问题描述

我有一个MxNxD数组 I 和一个二进制MxN掩码 M。

假设 M 中有k个 1。我想要的是提取一个kxD数组,其中包含与掩码中的 1 对应的所有 D 长度向量。

我可以通过调用numpy.nonzero()在 I 中获取这些向量的索引,但我找不到一个很好的紧凑方法来获取没有可怕循环的切片。

任何帮助都感激不尽。

标签: pythonarraysnumpyindexing

解决方案


我认为这就是你想要的:

In [283]: A = np.arange(24).reshape(2,3,4)
In [284]: M = np.array([[1,0,1],[0,1,0]],dtype=bool)
In [285]: A
Out[285]: 
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
In [286]: M
Out[286]: 
array([[ True, False,  True],
       [False,  True, False]])
In [287]: I,J = np.nonzero(M)
In [288]: I,J
Out[288]: (array([0, 0, 1]), array([0, 2, 1]))
In [289]: A[I,J,:]
Out[289]: 
array([[ 0,  1,  2,  3],
       [ 8,  9, 10, 11],
       [16, 17, 18, 19]])

由于M掩盖了初始尺寸,因此可以简化为

A[np.nonzero(M)]

推荐阅读