首页 > 解决方案 > 如何在 PyTorch 中重塑。[1, 257, 512] -> [1, 512, 16,16]

问题描述

我目前正在使用变压器进行图像生成。

我使用 Vit 部分,因为它用于编码器部分。

另外,我想附加一个transformer解码器并传递encoder输出,并将transformer解码器的输出放入CNN解码器中以创建图像。

和:

image size = 128 * 128
patch_size = 8
d_model = 512

变压器解码器的输出为[1, 257, 512]

[1,257,512]=>hw/64 * 512我想要重塑h/8 * w/8 * 512,但我不知道如何重塑它。

我怎样才能257变身256

我用

decoder_out = decoder_out.permute(0, 2, 1).view(1, self.d_model, 16, 16)

标签: pythonpytorchtransformer

解决方案


推荐阅读