首页 > 解决方案 > 如何创建一个 3D (shape=m,n,o) 数组,其中包含沿特定(第三)轴的数组,索引由 2D 数组(shape=m,n)给出?

问题描述

假设我有一个形状为 (2, 2) 的二维数组,其中包含索引

x = np.array([[2, 0], [3, 1]])

我想做的是创建一个形状为 (2, 2, 4) 的 3D 数组,该数组沿第三个轴的值为 1,它们的位置由 给出x,因此:

y = np.zeros(shape=(2,2,4))
myfunc(array=y, indices=x, axis=2)

array([[[0, 0, 1, 0],
        [1, 0, 0, 0]],
       [[0, 0, 0, 1],
        [0, 1, 0, 0]]])

到目前为止,我还没有找到任何索引方法。一个for循环也许可以做到这一点,但我确信有一个更快的矢量化方法。

标签: pythonarraysnumpyindexing

解决方案


您要查找的内容称为高级索引。要正确使用整数数组进行索引,您需要有一组广播到正确形状的数组。由于x已经与两个维度对齐,因此您只需要制作带有沿每个轴的索引的二维数组。np.ogrid对此有所帮助,因为它创建了广播到正确形状的最小范围数组:

a, b = np.ogrid[:2, :2]
y[a, b, x] = 1

结果ogrid相当于

a = np.arange(2).reshape(-1, 1)
b = np.arange(2).reshape(1, -1)

或者

a = np.arange(2)[:, None]
b = np.arange(2)[None, :]

你也可以写一个单行:

y[(*tuple(slice(None, n) for n in x.shape), x)] = 1

推荐阅读