首页 > 解决方案 > 索引时如何解包元组?

问题描述

我有一个非常高维的张量,比如形状为 5 X 10 X 100 X 200 X 50 的 A。我有一个返回元组 T 的 numpy 表达式,其中包含我想从 A 中提取的元素的索引。

我正在尝试这个:

A[*T]

它说:

语法无效,您不能在此处使用星号表达式。

我该怎么做?PS:长解是:A[T[0], T[1], T[2], T[3], T[4]]

编辑:我刚刚发现没有必要这样做,因为它是自动完成的。例子:

a= np.random.rand(3,3)
a[np.triu_indices(3)]

当作为索引传递给表达式np.triu_indices(3)时,表达式会自动解包。但是,回到我的问题并没有发生。具体来说,这里有一个例子:a

a = np.random.rand(100, 50, 14, 14)
a[:, :, np.triu_indices(14)].shape

假设最后一位np.triu_indices(14)应该作用在最后两个轴上,就像前面的例子一样,但它没有发生,而且结果的形状很奇怪。为什么不拆包?以及如何做到这一点?

标签: pythonnumpy

解决方案


问题在于:

a[:, :, np.triu_indices(14)]

是您将[...]混合类型slicetuple( tuple(slice, slice, tuple(np.ndarray, np.ndarray))) 的元组用作参数,而不是单个元组tuple(最终使用高级索引),例如tuple(slice, slice, np.ndarray, np.ndarray). 这给你带来了麻烦。我不会详细说明您的情况。

将该行更改为:

a[(slice(None),) * 2 + np.triu_indices(14)]

将解决您的问题:

a[(slice(None),) * 2 + np.triu_indices(14)].shape
# (100, 50, 105)

请注意,有几种重写方法:

(slice(None),) * 2 + np.triu_indices(14)

另一种方式可能是:

(slice(None), slice(None), *np.triu_indices(14))

推荐阅读