首页 > 解决方案 > 从其批处理版本重建火炬张量

问题描述

这是如何构建 3D 张量的非常好的示例:

import torch
y = torch.rand(100, 1)
batch_size = 10
batched_data = y.contiguous().view(batch_size, -1, y.size(-1)).transpose(0,1)
batched_data.shape

输出是:

torch.Size([10, 10, 1])

好的,现在我要做的是,从我要构建的 batched_data 开始。另一种方式。对强大的 pytorch 精简代码有什么好的建议吗?

==== 附加输入 =====

我将它用于 RNN,现在我有一些疑问,因为如果您考虑以下代码:

import torch
y = torch.arange(100).view(100,1)
batch_size = 10
batched_data = y.contiguous().view(batch_size, -1, y.size(-1)).transpose(0,1)
batched_data.shape

输出是:

tensor([[[ 0],
         [10],
         [20],
         [30],
         [40],
         [50],
         [60],
         [70],
         [80],
         [90]],

        [[ 1],
         [11],
         [21],
         [31],
         [41],
         [51],
         [61],
         [71],
         [81],
         [91]],

这是我没想到的。我希望是这样的: [[1,2,3,4,5,6,7,8,9,10],[11,12,13,14,15,16,17,18,19,20],....

标签: numpymachine-learningpytorchfeature-extractiontorch

解决方案


假设你想做这样的事情来重建 y:

rebuilded_y = batched_data.transpose(0,1).view(*y.shape)

要使输入看起来像您预期的那样,您需要删除 batched_data 中的转置和附加维度:

batched_data = y.contiguous().view(batch_size, -1)

推荐阅读