python - 将 4 维张量转换为列表列表(Python)
问题描述
我有 6 个形状的张量 (batch_size, S, S, 1),我想将它们组合成一个大小为 (batch_size, S*S, 6) 的 python 列表 - 所以张量的每个元素都应该在内部列表中。
这可以在不使用循环的情况下实现吗?解决它的有效方法是什么?
解决方案
让batch_size=10
和S=4
为了这个例子的目的:
>>> x = [torch.rand(10, 4, 4, 1) for _ in range(6)]
实际上,第一步是在最后一个维度上连接张量axis=3
:
>>> y = torch.cat(x, -1)
>>> y.shape
torch.Size([10, 4, 4, 6])
然后重塑以展平axis=1
和axis=2
,您可以在torch.flatten
此处执行此操作,因为两个轴相邻:
>>> y = torch.cat(x, -1).flatten(1, 2)
>>> y.shape
torch.Size([10, 16, 6])