python - 索引时如何解包元组?
问题描述
我有一个非常高维的张量,比如形状为 5 X 10 X 100 X 200 X 50 的 A。我有一个返回元组 T 的 numpy 表达式,其中包含我想从 A 中提取的元素的索引。
我正在尝试这个:
A[*T]
它说:
语法无效,您不能在此处使用星号表达式。
我该怎么做?PS:长解是:A[T[0], T[1], T[2], T[3], T[4]]
编辑:我刚刚发现没有必要这样做,因为它是自动完成的。例子:
a= np.random.rand(3,3)
a[np.triu_indices(3)]
当作为索引传递给表达式np.triu_indices(3)
时,表达式会自动解包。但是,回到我的问题并没有发生。具体来说,这里有一个例子:a
a = np.random.rand(100, 50, 14, 14)
a[:, :, np.triu_indices(14)].shape
假设最后一位np.triu_indices(14)
应该作用在最后两个轴上,就像前面的例子一样,但它没有发生,而且结果的形状很奇怪。为什么不拆包?以及如何做到这一点?
解决方案
问题在于:
a[:, :, np.triu_indices(14)]
是您将[...]
混合类型slice
和tuple
( tuple(slice, slice, tuple(np.ndarray, np.ndarray))
) 的元组用作参数,而不是单个元组tuple
(最终使用高级索引),例如tuple(slice, slice, np.ndarray, np.ndarray)
. 这给你带来了麻烦。我不会详细说明您的情况。
将该行更改为:
a[(slice(None),) * 2 + np.triu_indices(14)]
将解决您的问题:
a[(slice(None),) * 2 + np.triu_indices(14)].shape
# (100, 50, 105)
请注意,有几种重写方法:
(slice(None),) * 2 + np.triu_indices(14)
另一种方式可能是:
(slice(None), slice(None), *np.triu_indices(14))