python - numpy索引ndarray[(4, 2), (5, 3)]的解释
问题描述
问题
请帮助理解将元组 (i, j) 索引到 ndarray 中的 Numpy 索引的设计决策或合理性。
背景
当索引为单个元组 (4, 2) 时,则 (i=row, j=column)。
shape = (6, 7)
X = np.zeros(shape, dtype=int)
X[(4, 2)] = 1
X[(5, 3)] = 1
print("X is :\n{}\n".format(X))
---
X is :
[[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 0 0 0 0] <--- (4, 2)
[0 0 0 1 0 0 0]] <--- (5, 3)
但是,当索引是多个元组 (4, 2), (5, 3) 时,则 (i=row, j=row) 为 (4, 2) 和 (i=column, j=column) 为 (5, 3)。
shape = (6, 7)
Y = np.zeros(shape, dtype=int)
Y[(4, 2), (5, 3)] = 1
print("Y is :\n{}\n".format(Y))
---
Y is :
[[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 1 0 0 0] <--- (2, 3)
[0 0 0 0 0 0 0]
[0 0 0 0 0 1 0] <--- (4, 5)
[0 0 0 0 0 0 0]]
这意味着您正在构建一个二维数组R,这样
R=A[B, C]
. 这意味着 r ij =a b ij c ij的值。所以这意味着位于
R[0,0]
的项目是A
作为行索引B[0,0]
和列索引的项目C[0,0]
。该项目R[0,1]
是A
具有行索引B[0,1]
和列索引C[0,1]
等的项目。
multi_index:整数数组的元组,每个维度一个数组。
为什么不总是(i=row, j=column)?如果总是 (i=row, j=column) 会发生什么?
更新
通过 Akshay 和 @DaniMesejo 的回答,了解:
X[
(4), # dimension 2 indices with only 1 element
(2) # dimension 1 indices with only 1 element
] = 1
Y[
(4, 2, ...), # dimension 2 indices
(5, 3, ...) # dimension 1 indices (dimension 0 is e.g. np.array(3) whose shape is (), in my understanding)
] = 1
解决方案
很容易理解它是如何工作的(以及这个设计决策背后的动机)。
Numpy 将其 ndarray 存储为连续的内存块。每个元素在前一个元素之后每隔 n 个字节以顺序方式存储。
(引用自这篇优秀的 SO 帖子的图片)
因此,如果您的 3D 阵列看起来像这样 -
然后在内存中它的存储为 -
strides
当检索一个元素(或一个元素块)时,NumPy 计算它需要遍历多少(字节)才能获取下一个元素in that direction/axis
。因此,对于上面的示例,axis=2
它必须遍历 8 个字节(取决于datatype
),但axis=1
它必须遍历8*4
字节,并且axis=0
它需要8*8
字节。
考虑到这一点,让我们看看您要做什么。
print(X)
print(X.strides)
[[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[0 0 1 0 0 0 0]
[0 0 0 1 0 0 0]]
#Strides (bytes) required to traverse in each axis.
(56, 8)
对于您的数组,要获取 中的下一个元素axis=0
,我们需要遍历56 bytes
,而对于 中的下一个元素axis=1
,我们需要8 bytes
。
当您编制索引(4,2)
时,NumPy 会56*4
输入字节axis=0
和8*2
字节axis=1
来访问它。同样,如果要访问(4,2)
and (5,3)
,则必须访问56*(4,5)
inaxis=0
和8*(2,3)
in axis=1
。
这就是设计之所以如此的原因,因为它与 NumPy 实际使用strides
.
X[(axis0_indices), (axis1_indices), ..]
X[(4, 5), (2, 3)] #(row indices), (column indices)
array([1, 1])
通过这种设计,也很容易扩展到更高维的张量(比如 8 维数组)!如果您分别提到每个索引元组,则需要计算元素*维数才能获取这些。虽然使用这种设计,它可以将步幅值广播到每个轴的元组并更快地获取这些值!