首页 > 解决方案 > 将 4 维张量转换为列表列表(Python)

问题描述

我有 6 个形状的张量 (batch_size, S, S, 1),我想将它们组合成一个大小为 (batch_size, S*S, 6) 的 python 列表 - 所以张量的每个元素都应该在内部列表中。

这可以在不使用循环的情况下实现吗?解决它的有效方法是什么?

标签: pythonlistpytorchtensor

解决方案


batch_size=10S=4为了这个例子的目的:

 >>> x = [torch.rand(10, 4, 4, 1) for _ in range(6)]

实际上,第一步是在最后一个维度上连接张量axis=3

>>> y = torch.cat(x, -1)
>>> y.shape
torch.Size([10, 4, 4, 6])

然后重塑以展平axis=1axis=2,您可以在torch.flatten此处执行此操作,因为两个轴相邻:

>>> y = torch.cat(x, -1).flatten(1, 2)
>>> y.shape
torch.Size([10, 16, 6])

推荐阅读