首页 > 解决方案 > 在 PyTorch 中使用 module.to() 移动成员张量

问题描述

我正在 PyTorch 中构建变分自动编码器 (VAE),但在编写与设备无关的代码时遇到问题。Autoencoder 是nn.Module具有编码器和解码器网络的子网络,它们也是。网络的所有权重都可以通过调用从一个设备移动到另一个设备net.to(device)

我遇到的问题是重新参数化技巧:

encoding = mu + noise * sigma

噪声是与 和 大小相同的张量,musigma保存为自动编码器模块的成员变量。它在构造函数中初始化,并在每个训练步骤就地重新采样。我这样做是为了避免每一步都构建一个新的噪声张量并将其推送到所需的设备。此外,我想修复评估中的噪音。这是代码:

class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise

        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

    def train(self, mode=True):
        super(VariationalGenerator, self).train(mode)
        self.noise = self._train_noise

    def eval(self):
        super(VariationalGenerator, self).eval()
        self.noise = self._eval_noise

    def forward(self, inputs):
        # Calculate parameters of embedding space
        mu, log_sigma = self.encoder.forward(inputs)
        # Resample noise if training
        if self.training:
            self.noise.normal_()
        # Reparametrize noise to embedding space
        inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
        # Decode to image
        inputs = self.decoder(inputs)

        return inputs, mu, log_sigma

当我现在将自动编码器移动到 GPU 时,net.to('cuda:0')由于噪声张量没有移动,我在转发时遇到错误。

我不想在构造函数中添加设备参数,因为以后仍然无法将其移动到另一个设备。我还尝试将噪声包装到nn.Parameter中,使其受 影响net.to(),但这会导致优化器出错,因为噪声被标记为requires_grad=False

任何人都有解决方案来移动所有模块net.to()

标签: pythondeep-learninggpupytorchautoencoder

解决方案


经过反复试验,我发现了两种方法:

  1. 使用缓冲区:通过替换self._train_noise = torch.randn(batch_size, embedding_size)噪声self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size)张量作为缓冲区添加到模块中。这net.to(device)也会影响它。此外,张量现在是 state_dict 的一部分。
  2. Override net.to(device):使用它,噪音不会出现在 state_dict 之外。

    def to(device):
        new_self = super(VariationalGenerator, self).to(device)
        new_self._train_noise = new_self._train_noise.to(device)
        new_self._eval_noise = new_self._eval_noise.to(device)
    
        return new_self
    

推荐阅读