machine-learning - 如何解决尺寸误差并有效使用Conv2dTranspose?
问题描述
我已经创建了一个鉴别器和生成器文件来实现 GAN,但是,我正面临这个错误。
我面临的最初错误是在 main.py 文件中,我在其中调用标准库并传递输出和标签。我使用挤压功能解决了该错误,从而解决了形状问题。
在使用挤压之前,错误显示输出和标签的形状不匹配(输出和标签的形状分别为(7,1,1,1)和(7)。
import torch
from torch import nn
class generatorG(nn.Module):
def __init__(self):
super(generatorG, self).__init__()
self.t1 = nn.Sequential(
nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size= (4,4), stride = 2,padding = 1),
nn.LeakyReLU(0.2, inplace = True)
)
self.t2 = nn.Sequential(
nn.Conv2d(in_channels= 64, out_channels = 64, kernel_size = (4,4), stride = 2,padding = 1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace = True)
)
self.t3 = nn.Sequential(
nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 4, stride = 2, padding =1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True)
)
self.t4 = nn.Sequential(
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(4, 4), stride = 2,padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True)
)
self.t5 = nn.Sequential(
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True)
)
self.t6 = nn.Sequential(
nn.Conv2d(in_channels=512, out_channels=4000, kernel_size=(4, 4)),
nn.BatchNorm2d(4000),
nn.ReLU()
)
self.t7 = nn.Sequential(
nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size =4, stride = 2, padding = 1),
nn.BatchNorm2d(256),
nn.ReLU()
)
self.t8 = nn.Sequential(
nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.t9 = nn.Sequential(
nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.t10 = nn.Sequential(
nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def forward(self, x):
x = self.t1(x)
x = self.t2(x)
x = self.t3(x)
x = self.t4(x)
x = self.t5(x)
x = self.t6(x)
x = self.t7(x)
x = self.t8(x)
x = self.t9(x)
x = self.t10(x)
return x
model = generatorG()
print(model(torch.randn()).shape)
鉴别器文件
import torch
from torch import nn
class DiscriminatorD(nn.Module):
def __init__(self):
super(DiscriminatorD, self).__init__()
self.t1 = nn.Sequential(
nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size =4, stride = 2, padding = 1),
nn.LeakyReLU(0.2, inplace = True)
)
self.t2 = nn.Sequential(
nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 4, stride = 2, padding = 1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True)
)
self.t3 = nn.Sequential(
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True)
)
self.t4 = nn.Sequential(
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True)
)
self.t5 = nn.Sequential(
nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0),
nn.Sigmoid()
)
def forward(self, x):
x = self.t1(x)
x = self.t2(x)
x = self.t3(x)
x = self.t4(x)
x = self.t5(x)
return x
main.py 文件
from generator import *
from discriminator import *
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import utils
epochs = 100
Batch_Size = 64
lr = 0.0002
beta1 = 0.5
over = 4
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', default = 'dataset/train', help = 'path to dataset')
opt = parser.parse_args()
try:
os.makedirs('result/train/cropped')
os.makedirs('result/train/real')
os.makedirs('result/train/recon')
os.makedirs('model/')
except:
pass
transform = transforms.Compose([transforms.Scale(128),
transforms.CenterCrop(128),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
dataset = dset.ImageFolder(root=opt.dataroot, transform= transform)
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=Batch_Size, shuffle=True, num_workers=0)
wtl2 = 0.999
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv')!=-1:
m.weight.data.normal_(0.0,0.2)
elif classname.find('BatchNorm')!=-1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
resume_epoch = 0
netG = generatorG()
netG.apply(weights_init)
netD = DiscriminatorD()
netD.apply(weights_init)
criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()
input_real = torch.FloatTensor(Batch_Size, 3, 128, 128)
input_cropped = torch.FloatTensor(Batch_Size, 3, 128, 128)
label = torch.FloatTensor(Batch_Size)
real_label = 1
fake_label = 0
real_center = torch.FloatTensor(Batch_Size, 3, 64, 64)
input_real = Variable(input_real)
input_cropped = Variable(input_cropped)
label = Variable(label)
real_center = Variable(real_center)
optimizerD = optim.Adam(netD.parameters(), lr = lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = lr, betas = (beta1, 0.999))
over = 4
for epoch in range(resume_epoch, epochs):
for i, data in enumerate(dataloader, 0):
real_cpu, _ = data
real_center_cpu = real_cpu[:,:,int(128/4):int(128/4)+int(128/2),int(128/4):int(128/4)+int(128/2)]
batch_size = real_cpu.size(0)
with torch.no_grad():
input_real.resize_(real_cpu.size()).copy_(real_cpu)
input_cropped.resize_(real_cpu.size()).copy_(real_cpu)
real_center.resize_(real_center_cpu.size()).copy_(real_center_cpu)
input_cropped[:, 0, int(128 / 4 + over):int(128 / 4 + 128 / 2 - over),int(128 / 4 + over):int(128 / 4 + 128 / 2 - over)] = 2 * 117.0 / 255.0 - 1.0
input_cropped[:, 1, int(128 / 4 + over):int(128 / 4 + 128 / 2 - over),int(128 / 4 + over):int(128 / 4 + 128 / 2 - over)] = 2 * 104.0 / 255.0 - 1.0
input_cropped[:, 2, int(128 / 4 + over):int(128 / 4 + 128 / 2 - over),int(128 / 4 + over):int(128 / 4 + 128 / 2 - over)] = 2 * 123.0 / 255.0 - 1.0
netD.zero_grad()
with torch.no_grad():
label.resize_(batch_size).fill_(real_label)
output = netD(real_center)
# output = torch.unsqueeze(output[0, 1)
output = torch.squeeze(output, 1)
output = torch.squeeze(output, 1)
output = torch.squeeze(output, 1)
print(output.shape)
# label = label.unsqueeze(1)
# label = label.unsqueeze(1)
# label = label.unsqueeze(1)
print(label.shape)
errD_real = criterion(output, label)
errD_real.backward()
D_x = output.data.mean()
print(input_cropped.shape)
fake = netG(input_cropped)
label.data.fill_(fake_label)
output = netD(fake.detach())
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()
netG.zero_grad()
label.data.fill_(real_label) # fake labels are real for generator cost
output = netD(fake)
errG_D = criterion(output, label)
wtl2Matrix = real_center.clone()
wtl2Matrix.data.fill_(wtl2 * 10)
wtl2Matrix.data[:, :, int(over):int(128 / 2 - over), int(over):int(128 / 2 - over)] = wtl2
errG_l2 = (fake - real_center).pow(2)
errG_l2 = errG_l2 * wtl2Matrix
errG_l2 = errG_l2.mean()
errG = (1 - wtl2) * errG_D + wtl2 * errG_l2
errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step()
print('[%d / %d][%d / %d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
% (epoch, epochs, i, len(dataloader),
errD.data, errG_D.data, errG_l2.data, D_x, D_G_z1,))
if i % 100 == 0:
vutils.save_image(real_cpu,
'result/train/real/real_samples_epoch_%03d.png' % (epoch))
vutils.save_image(input_cropped.data,
'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
recon_image = input_cropped.clone()
recon_image.data[:, :, int(128 / 4):int(128 / 4 + 128 / 2), int(128 / 4):int(128 / 4 + 128 / 2)] = fake.data
vutils.save_image(recon_image.data,
'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))
实用程序文件
import torch
from PIL import Image
from torch.autograd import Variable
def load_image(filename, size = None, scale = None):
img = Image.open(filename)
if size is not None:
img = img.resize((size, size), Image.ANTIALIAS)
elif scale is not None:
img = img.resize((int(img.size[0]/scale), int(img.size[1]/scale)), Image.ANTIALIAS)
return img
def save_image(filename, data):
img = data.clone().add(1).div(2).mul(255).clamp(0,255).numpy()
img = img.transpose(1,2,0).astype('uint8')
img = Image.fromarray(img)
img.save(filename)
def gram_matrix(y):
(b, ch, h, w) = y.size()
features = y.view(b, ch, w*h)
features_t = features.transpose(1,2)
gram = features.bmm(features_t)/(ch*h*w)
return gram
def normalize_batch(batch):
mean = batch.data.new(batch.data.size())
std = batch.data.new(batch.data.size())
mean[:, 0, :, :] = 0.485
mean[:, 1, :, :] = 0.456
mean[:, 2, :, :] = 0.406
std[:, 0, :, :] = 0.229
std[:, 1, :, :] = 0.224
std[:, 2, :, :] = 0.225
batch = torch.div(batch, 255.0)
batch -= Variable(mean)
# batch /= Variable(std)
batch = torch.div(batch, Variable(std))
return batch
错误信息
(impaint_env) vivek@Viveks-MacBook-Pro image_impainter % python main.py
/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
torch.Size([7])
torch.Size([7])
torch.Size([7, 3, 128, 128])
Traceback (most recent call last):
File "main.py", line 114, in <module>
fake = netG(input_cropped)
File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/vivek/DSwork/image_impainter/generator.py", line 70, in forward
x = self.t7(x)
File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/vivek/DSwork/image_impainter/impaint_env/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 916, in forward
return F.conv_transpose2d(
RuntimeError: Given transposed=1, weight of size [512, 256, 4, 4], expected input[7, 4000, 1, 1] to have 512 channels, but got 4000 channels instead
解决方案
t6
你在 layer和t7
你的之间有一个“差距” generatorG
:
# ...
self.t6 = nn.Sequential(
nn.Conv2d(in_channels=512, out_channels=4000, kernel_size=(4, 4)),
nn.BatchNorm2d(4000),
nn.ReLU()
)
self.t7 = nn.Sequential(
nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size =4, stride = 2, padding = 1),
nn.BatchNorm2d(256),
nn.ReLU()
)
# ...
您的t6
层期望输入具有 512 个通道并输出具有 4000 个通道的张量。然而,下一层,t7
预计输入只有 512 个通道。
您需要进行调整t6
,t7
以便t6
输出与预期完全相同的通道数。t7
那就是t6
' out_channles
,只是相等t7
' in_channels
。
推荐阅读
- node.js - redirect_uri_mismatch: http://localhost:4500/AUTH/auth/google/callback,不匹配
- swift - 条件绑定的初始化程序必须具有 Optional 类型,而不是 Swift 5 中的 CustomClass
- java - 如何使用 LocalDateTime 查询今天的所有实体?
- c# - 如何防止同一类型表单的对象同时添加到列表中
- reactjs - 这是呈现组件的标准方式吗?
- python-3.x - Boto3 担任 MFA 的跨账户角色
- android - 我使用 hupu 应用程序时的 Android 11 文件共享错误
- html - 桌面:div 并排包含 3 个元素
- python - I got error while getting all the list of members from guild
- xml - 读取 XML 文件元素和之间的值