python - 沿特定维度和特定通道索引整个张量
问题描述
假设我们有一个尺寸为dim(A)=[i, j, k=6, u, v]的张量A。现在我们有兴趣通过channels=[0:3]获得维度k的整个张量。我知道我们可以这样得到:
B = A[:, :, 0:3, :, :]
现在我想知道是否有任何更好的“pythonic”方式来实现相同的结果,而无需进行这种次优索引。我的意思是类似的东西。
B = subset(A, dim=2, index=[0, 1, 2])
无论在哪个框架下,即pytorch、tensorflow、numpy等。
非常感谢
解决方案
在numpy中,您可以使用以下take
方法:
B = A.take([0,1,2], axis=2)
在 TensorFlow 中,没有比使用传统方法更简洁的方法了。使用tf.slice
会非常冗长:
B = tf.slice(A,[0,0,0,0,0],[-1,-1,3,-1,-1])
您可能会使用take
(自 TF 2.4 起)的实验版本:
B = tf.experimental.numpy.take(A, [0,1,2], axis=2)
在 PyTorch 中,您可以使用index_select
:
torch.index_select(A, dim=2, index=torch.tensor([0,1,2]))
请注意,您可以使用 : 明确地跳过列出第一个维度(或最后一个维度)ellipsis
:
# Both are equivalent in that case
B = A[..., 0:3, :, :]
B = A[:, :, 0:3, ...]
推荐阅读
- opencv - 三角剖分 3D 点导致扭曲(抛物面)图的原因可能是什么?尝试使用 SFM 执行 3D 重建
- python - 如何解决此问题“ModuleNotFoundError: No module named 'torch.tensor'”
- python - 基于列标识符的聚合计数
- c++ - Cmake - 为不同的源选择不同的c++标准
- algorithm - 递归关系树方法紧界
- python - 我正在创建一个(非常粗糙的)计算器,并想知道这段代码有什么问题
- linux - 虚拟机中的 npm 命令
- flutter - 用空值或空值初始化有什么区别吗?
- julia - 在约束 JuMP 中编码数组
- c++ - 为什么 C++ 允许在 case 标签下定义对象,即使这些 case 与检查值不匹配,这与 if 语句相反?