首页 > 解决方案 > 如何解决尺寸误差并有效使用Conv2dTranspose?

问题描述

我已经创建了一个鉴别器和生成器文件来实现 GAN,但是,我正面临这个错误。

我面临的最初错误是在 main.py 文件中,我在其中调用标准库并传递输出和标签。我使用挤压功能解决了该错误,从而解决了形状问题。

在使用挤压之前,错误显示输出和标签的形状不匹配(输出和标签的形状分别为(7,1,1,1)和(7)。

import torch
from torch import nn
class generatorG(nn.Module):
    def __init__(self):
        super(generatorG, self).__init__()

        self.t1 = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size= (4,4), stride = 2,padding = 1),
            nn.LeakyReLU(0.2, inplace = True)
        )

        self.t2 = nn.Sequential(
            nn.Conv2d(in_channels= 64, out_channels = 64, kernel_size = (4,4), stride = 2,padding = 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace = True)
        )

        self.t3 = nn.Sequential(
            nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 4, stride = 2, padding =1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True)
        )
        self.t4 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(4, 4), stride = 2,padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.t5 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.t6 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=4000, kernel_size=(4, 4)),
            nn.BatchNorm2d(4000),
            nn.ReLU()
        )

        self.t7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size =4, stride = 2, padding = 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        self.t8 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        self.t9 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.t10 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.t1(x)
        x = self.t2(x)
        x = self.t3(x)
        x = self.t4(x)
        x = self.t5(x)
        x = self.t6(x)
        x = self.t7(x)
        x = self.t8(x)
        x = self.t9(x)
        x = self.t10(x)

        return x


model = generatorG()
print(model(torch.randn()).shape)

鉴别器文件

import torch
from torch import nn

class DiscriminatorD(nn.Module):
    def __init__(self):
        super(DiscriminatorD, self).__init__()

        self.t1 = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size =4, stride = 2, padding = 1),
            nn.LeakyReLU(0.2, inplace = True)
        )

        self.t2 = nn.Sequential(
            nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 4, stride = 2, padding = 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True)
        )

        self.t3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.t4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.t5 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.t1(x)
        x = self.t2(x)
        x = self.t3(x)
        x = self.t4(x)
        x = self.t5(x)

        return x

main.py 文件

from generator import *
from discriminator import *
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

import utils

epochs = 100
Batch_Size = 64
lr = 0.0002
beta1 = 0.5
over = 4
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', default = 'dataset/train', help = 'path to dataset')
opt = parser.parse_args()

try:
    os.makedirs('result/train/cropped')
    os.makedirs('result/train/real')
    os.makedirs('result/train/recon')
    os.makedirs('model/')

except:
    pass

transform = transforms.Compose([transforms.Scale(128),
                                transforms.CenterCrop(128),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
dataset = dset.ImageFolder(root=opt.dataroot, transform= transform)
assert dataset

dataloader = torch.utils.data.DataLoader(dataset, batch_size=Batch_Size, shuffle=True, num_workers=0)
wtl2 = 0.999

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv')!=-1:
        m.weight.data.normal_(0.0,0.2)
    elif classname.find('BatchNorm')!=-1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

resume_epoch = 0
netG = generatorG()
netG.apply(weights_init)

netD = DiscriminatorD()
netD.apply(weights_init)

criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()

input_real = torch.FloatTensor(Batch_Size, 3, 128, 128)
input_cropped = torch.FloatTensor(Batch_Size, 3, 128, 128)
label = torch.FloatTensor(Batch_Size)
real_label = 1
fake_label = 0

real_center = torch.FloatTensor(Batch_Size, 3, 64, 64)

input_real = Variable(input_real)
input_cropped = Variable(input_cropped)
label = Variable(label)

real_center = Variable(real_center)

optimizerD = optim.Adam(netD.parameters(), lr = lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = lr, betas = (beta1, 0.999))
over = 4
for epoch in range(resume_epoch, epochs):
    for i, data in enumerate(dataloader, 0):
        real_cpu, _ = data
        real_center_cpu = real_cpu[:,:,int(128/4):int(128/4)+int(128/2),int(128/4):int(128/4)+int(128/2)]
        batch_size = real_cpu.size(0)

        with torch.no_grad():
            input_real.resize_(real_cpu.size()).copy_(real_cpu)
            input_cropped.resize_(real_cpu.size()).copy_(real_cpu)
            real_center.resize_(real_center_cpu.size()).copy_(real_center_cpu)
            input_cropped[:, 0, int(128 / 4 + over):int(128 / 4 + 128 / 2 - over),int(128 / 4 + over):int(128 / 4 + 128 / 2 - over)] = 2 * 117.0 / 255.0 - 1.0
            input_cropped[:, 1, int(128 / 4 + over):int(128 / 4 + 128 / 2 - over),int(128 / 4 + over):int(128 / 4 + 128 / 2 - over)] = 2 * 104.0 / 255.0 - 1.0
            input_cropped[:, 2, int(128 / 4 + over):int(128 / 4 + 128 / 2 - over),int(128 / 4 + over):int(128 / 4 + 128 / 2 - over)] = 2 * 123.0 / 255.0 - 1.0

        netD.zero_grad()
        with torch.no_grad():
            label.resize_(batch_size).fill_(real_label)

        output = netD(real_center)
        # output = torch.unsqueeze(output[0, 1)
        output = torch.squeeze(output, 1)
        output = torch.squeeze(output, 1)
        output = torch.squeeze(output, 1)
        print(output.shape)
        # label = label.unsqueeze(1)
        # label = label.unsqueeze(1)
        # label = label.unsqueeze(1)

        print(label.shape)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.data.mean()
        print(input_cropped.shape)
        fake = netG(input_cropped)
        label.data.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake
        optimizerD.step()

        netG.zero_grad()
        label.data.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG_D = criterion(output, label)

        wtl2Matrix = real_center.clone()
        wtl2Matrix.data.fill_(wtl2 * 10)
        wtl2Matrix.data[:, :, int(over):int(128 / 2 - over), int(over):int(128 / 2 - over)] = wtl2

        errG_l2 = (fake - real_center).pow(2)
        errG_l2 = errG_l2 * wtl2Matrix
        errG_l2 = errG_l2.mean()

        errG = (1 - wtl2) * errG_D + wtl2 * errG_l2

        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()

        print('[%d / %d][%d / %d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
              % (epoch, epochs, i, len(dataloader),
                 errD.data, errG_D.data, errG_l2.data, D_x, D_G_z1,))

        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              'result/train/real/real_samples_epoch_%03d.png' % (epoch))
            vutils.save_image(input_cropped.data,
                              'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
            recon_image = input_cropped.clone()
            recon_image.data[:, :, int(128 / 4):int(128 / 4 + 128 / 2), int(128 / 4):int(128 / 4 + 128 / 2)] = fake.data
            vutils.save_image(recon_image.data,
                              'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))

实用程序文件

import torch
from PIL import Image
from torch.autograd import Variable

def load_image(filename, size = None, scale = None):
    img = Image.open(filename)
    if size is not None:
        img = img.resize((size, size), Image.ANTIALIAS)
    elif scale is not None:
        img = img.resize((int(img.size[0]/scale), int(img.size[1]/scale)), Image.ANTIALIAS)

    return img

def save_image(filename, data):
    img = data.clone().add(1).div(2).mul(255).clamp(0,255).numpy()
    img = img.transpose(1,2,0).astype('uint8')
    img = Image.fromarray(img)
    img.save(filename)

def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w*h)
    features_t = features.transpose(1,2)
    gram = features.bmm(features_t)/(ch*h*w)
    return gram

def normalize_batch(batch):
    mean = batch.data.new(batch.data.size())
    std = batch.data.new(batch.data.size())
    mean[:, 0, :, :] = 0.485
    mean[:, 1, :, :] = 0.456
    mean[:, 2, :, :] = 0.406
    std[:, 0, :, :] = 0.229
    std[:, 1, :, :] = 0.224
    std[:, 2, :, :] = 0.225
    batch = torch.div(batch, 255.0)
    batch -= Variable(mean)
    # batch /= Variable(std)
    batch = torch.div(batch, Variable(std))
    return batch

错误信息

(impaint_env) vivek@Viveks-MacBook-Pro image_impainter % python main.py
/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +
torch.Size([7])
torch.Size([7])
torch.Size([7, 3, 128, 128])
Traceback (most recent call last):
  File "main.py", line 114, in <module>
    fake = netG(input_cropped)
  File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/vivek/DSwork/image_impainter/generator.py", line 70, in forward
    x = self.t7(x)
  File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 916, in forward
    return F.conv_transpose2d(
RuntimeError: Given transposed=1, weight of size [512, 256, 4, 4], expected input[7, 4000, 1, 1] to have 512 channels, but got 4000 channels instead

标签: machine-learningneural-networkpytorchconv-neural-networkgenerative-adversarial-network

解决方案


t6你在 layer和t7你的之间有一个“差距” generatorG

        # ...
        self.t6 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=4000, kernel_size=(4, 4)),
            nn.BatchNorm2d(4000),
            nn.ReLU()
        )

        self.t7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size =4, stride = 2, padding = 1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        # ...

您的t6层期望输入具有 512 个通道并输出具有 4000 个通道的张量。然而,下一层,t7预计输入只有 512 个通道。

您需要进行调整t6t7以便t6输出与预期完全相同的通道数。t7那就是t6' out_channles,只是相等t7' in_channels


推荐阅读