首页 > 解决方案 > 在 numpy 中使用另一个索引数组访问数组

问题描述

我有一个形状为 [10, 10, 10] 的数组 A。

现在我想通过使用另一个包含索引的形状为 [10, 10, 10, 3] 的数组 B 来访问该数组。

作为输出,我想得到一个形状为 [10, 10, 10] 的数组 C。因此 B 中的每个索引都被 A 中的相应值“替换”。

不幸的是,尽管迭代索引数组并为每个元素逐步获取 A 的每个相应值,但我找不到合适的答案来解决该问题。我正在寻找更有效的解决方案。

非常感谢!

标签: pythonarrayspython-3.xnumpy

解决方案


这里有两种方法可以做你想做的事。第一个使用循环,第二个不使用。第二个速度快了大约 10 倍。

解决方案 1 - 循环

import numpy as np
a = np.random.normal(0,1,(10,10,10)) # original array

b = np.random.randint(0,10, (10,10,10,3)) # array of positions

c = np.zeros((10,10, 10)) # new array

for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        for k in range(a.shape[2]):
            c[i,j,k]=a[tuple(b[i,j,k])]

这种方法在我的笔记本电脑上大约需要 4ms

以上可以作为比较的基准。现在这里是用数组切片和没有循环完成的同样的事情。

解决方案 2 - 没有循环。更高效的数组切片

a_original_shape = a.shape
# reshape b to make it (10**3, 3) 
# b_flat[:,0] hold all the i coords, b_flat[:,0] holds j coords etc
b_flat = b.reshape( (np.product(a_original_shape),)+(3,) )

# slice out the values we want from a. This gives us a 1D array
c2_flat = a[[b_flat[:,i] for i in range(3)]]

# reshape it back to the original form. 
# All values will be the correct place in this new array
c2 = c2_flat.reshape(a_original_shape)

这种方法在我的笔记本电脑上大约需要 0.5ms

您可以使用以下方法检查这些方法是否提供相同的东西

np.all(c2==c)
True

这里的第二个解决方案大约需要第一个解决方案的 10% 的时间


推荐阅读