python - RuntimeError: running_mean 应该包含 256 个元素而不是 128 个 pytorch
问题描述
我是 PyTorch、GAN 的新手,对 Python 没有太多经验(虽然我是 C/C++ 程序员)。
我有一个用于生成假图像的简单 DCGAN 教程代码,当我使用“DATASETNAME = 'MNIST'”运行代码时就可以了。但是,当我将数据集更改为“CIFAR10”时,程序会产生与“running_mean”相关的错误。
代码如下
将 torch.nn 导入为 nn
def weights_init(模块):
if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
module.weight.detach().normal_(mean=0., std=0.02)
elif isinstance(module, nn.BatchNorm2d):
module.weight.detach().normal_(1., 0.02)
module.bias.detach().zero_()
else:
pass
类视图(nn.Module):
def __init__(self, output_shape):
super(View, self).__init__()
self.output_shape = output_shape
def forward(self, x):
return x.view(x.shape[0], *self.output_shape)
类生成器(nn.Module):
def __init__(self, dataset_name):
super(Generator, self).__init__()
act = nn.ReLU(inplace=True)
norm = nn.BatchNorm2d
if dataset_name == 'CIFAR10': # Output shape 3x32x32
model = [nn.Linear(100, 512 * 4 * 4), View([512, 4, 4]), norm(512), act] # 4x4
model += [nn.ConvTranspose2d(512, 256, 5, stride=2, padding=2, output_padding=1), norm(256), act] # 8x8
model += [nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1), norm(128), act] # 16x16
model += [nn.ConvTranspose2d(128, 3, 5, stride=2, padding=2, output_padding=1), nn.Tanh()] # 32x32
elif dataset_name == 'LSUN': # Output shape 3x64x64
model = [nn.Linear(100, 1024 * 4 * 4), View([1024, 4, 4]), norm(1024), act] # 4x4
model += [nn.ConvTranspose2d(1024, 512, 5, stride=2, padding=2, output_padding=1), norm(512), act] # 8x8
model += [nn.ConvTranspose2d(512, 256, 5, stride=2, padding=2, output_padding=1), norm(256), act] # 16x16
model += [nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1), norm(128), act] # 32x32
model += [nn.ConvTranspose2d(128, 3, 5, stride=2, padding=2, output_padding=1), nn.Tanh()] # 64x64
elif dataset_name == 'MNIST': # Output shape 1x28x28
model = [nn.Linear(100, 256 * 4 * 4), View([256, 4, 4]), norm(256), act] # 4x4
model += [nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2), norm(128), act] # 7x7
model += [nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1), norm(64), act] # 14x14
model += [nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1), nn.Tanh()] # 28x28
else:
raise NotImplementedError
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
类鉴别器(nn.Module):
def __init__(self, dataset_name):
super(Discriminator, self).__init__()
act = nn.LeakyReLU(inplace=True, negative_slope=0.2)
norm = nn.BatchNorm2d
if dataset_name == 'CIFAR10': # Input shape 3x32x32
model = [nn.Conv2d(3, 128, 5, stride=2, padding=2, bias=False), act] # 16x16
model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(128), act] # 8x8
model += [nn.Conv2d(256, 512, 5, stride=2, padding=2, bias=False), norm(256), act] # 4x4
model += [nn.Conv2d(512, 1, 4, stride=2, padding=2, bias=False), nn.Sigmoid()] # 1x1
elif dataset_name == 'LSUN': # Input shape 3x64x64
model = [nn.Conv2d(3, 128, 5, stride=2, padding=2, bias=False), act] # 128x32x32
model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(128), act] # 256x16x16
model += [nn.Conv2d(256, 512, 5, stride=2, padding=2, bias=False), norm(256), act] # 512x8x8
model += [nn.Conv2d(512, 1024, 5, stride=2, padding=2, bias=False), norm(512), act] # 1024x4x4
model += [nn.Conv2d(1024, 1, 4), nn.Sigmoid()] # 1x1x1
elif dataset_name == 'MNIST': # Input shape 1x28x28
model = [nn.Conv2d(1, 64, 5, stride=2, padding=2, bias=False), act] # 14x14
model += [nn.Conv2d(64, 128, 5, stride=2, padding=2, bias=False), norm(128), act] # 7x7
model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(256), act] # 4x4
model += [nn.Conv2d(256, 1, 4, bias=False), nn.Sigmoid()] # 1x1
else:
raise NotImplementedError
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
如果名称== '主要':
import os
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from torch.utils.data import DataLoader
#from models import Discriminator, Generator, weights_init
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from time import time
from tqdm import tqdm
from torchvision.utils import save_image
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
BETA1, BETA2 = 0.5, 0.99
BATCH_SIZE = 16
DATASET_NAME = 'CIFAR10'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')
EPOCHS = 1
ITER_REPORT = 10
LATENT_DIM = 100
LR = 2e-4
N_D_STEP = 1
ITER_DISPLAY = 500
IMAGE_DIR = './GAN/checkpoints/'+DATASET_NAME+'/Image'
MODEL_DIR = './GAN/checkpoints/'+DATASET_NAME+'/Model'
if DATASET_NAME == 'CIFAR10':
IMAGE_SIZE = 32
OUT_CHANNEL = 3
from torchvision.datasets import CIFAR10
transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])
dataset = CIFAR10(root='./datasets', train=True, transform=transforms, download=True)
elif DATASET_NAME == 'LSUN':
IMAGE_SIZE = 64
OUT_CHANNEL = 3
from torchvision.datasets import LSUN
transforms = Compose([Resize(64), ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
dataset = LSUN(root='./datasets/LSUN', classes=['bedroom_train'], transform=transforms)
elif DATASET_NAME == 'MNIST':
IMAGE_SIZE = 28
OUT_CHANNEL = 1
from torchvision.datasets import MNIST
transforms = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])
dataset = MNIST(root='./datasets', train=True, transform=transforms, download=True)
else:
raise NotImplementedError
data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)
D = Discriminator(DATASET_NAME).apply(weights_init).to(DEVICE)
G = Generator(DATASET_NAME).apply(weights_init).to(DEVICE)
print(D, G)
criterion = nn.BCELoss()
optim_D = torch.optim.Adam(D.parameters(), lr=LR, betas=(BETA1, BETA2))
optim_G = torch.optim.Adam(G.parameters(), lr=LR, betas=(BETA1, BETA2))
list_D_loss = list()
list_G_loss = list()
total_step = 0
st = time()
for epoch in range(EPOCHS):
for data in tqdm(data_loader):
total_step += 1
real, label = data[0].to(DEVICE), data[1].to(DEVICE)
z = torch.randn(BATCH_SIZE, LATENT_DIM).to(DEVICE)
fake = G(z)
real_score = D(real)
fake_score = D(fake.detach())
D_loss = 0.5 * (criterion(fake_score, torch.zeros_like(fake_score).to(DEVICE))
+ criterion(real_score, torch.ones_like(real_score).to(DEVICE)))
optim_D.zero_grad()
D_loss.backward()
optim_D.step()
list_D_loss.append(D_loss.detach().cpu().item())
if total_step % ITER_DISPLAY == 0:
#(BatchSize, Channel*ImageSize*ImageSize)-->(BatchSize, Channel, ImageSize, ImageSize)
fake = fake.view(BATCH_SIZE, OUT_CHANNEL, IMAGE_SIZE, IMAGE_SIZE)
real = real.view(BATCH_SIZE, OUT_CHANNEL, IMAGE_SIZE, IMAGE_SIZE)
save_image(fake, IMAGE_DIR + '/{}_fake.png'.format(epoch + 1), nrow=4, normalize=True)
save_image(real, IMAGE_DIR + '/{}_real.png'.format(epoch + 1), nrow=4, normalize=True)
if total_step % N_D_STEP == 0:
fake_score = D(fake)
G_loss = criterion(fake_score, torch.ones_like(fake_score))
optim_G.zero_grad()
G_loss.backward()
optim_G.step()
list_G_loss.append(G_loss.detach().cpu().item())
if total_step % ITER_REPORT == 0:
print("Epoch: {}, D_loss: {:.{prec}} G_loss: {:.{prec}}"
.format(epoch, D_loss.detach().cpu().item(), G_loss.detach().cpu().item(), prec=4))
torch.save(D.state_dict(), '{}_D.pt'.format(DATASET_NAME))
torch.save(G.state_dict(), '{}_G.pt'.format(DATASET_NAME))
plt.figure()
plt.plot(range(0, len(list_D_loss)), list_D_loss, linestyle='--', color='r', label='Discriminator loss')
plt.plot(range(0, len(list_G_loss) * N_D_STEP, N_D_STEP), list_G_loss, linestyle='--', color='g',
label='Generator loss')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.legend()
plt.savefig('Loss.png')
print(time() - st)
该错误似乎来自 Discriminator .forward 如下:
RuntimeError Traceback (most recent call last)
in
71 fake = G(z)
72
?> 73 real_score = D(real)
74 fake_score = D(fake.detach())
75
C:\Anaconda3\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
?> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
in forward(self, x)
87
88 def forward(self, x):
?> 89 return self.model(x)
C:\Anaconda3\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
?> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
C:\Anaconda3\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
90 def forward(self, input):
91 for module in self._modules.values():
?> 92 input = module(input)
93 return input
94
C:\Anaconda3\lib\site-packages\torch\nn\modules\module.py in call(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
?> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
C:\Anaconda3\lib\site-packages\torch\nn\modules\batchnorm.py in forward(self, input)
81 input, self.running_mean, self.running_var, self.weight, self.bias,
82 self.training or not self.track_running_stats,
?> 83 exponential_average_factor, self.eps)
84
85 def extra_repr(self):
C:\Anaconda3\lib\site-packages\torch\nn\functional.py in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
1695 return torch.batch_norm(
1696 input, weight, bias, running_mean, running_var,
-> 1697 training, momentum, eps, torch.backends.cudnn.enabled
1698 )
1699
RuntimeError: running_mean should contain 256 elements not 128
谁能告诉我这个错误是什么?它似乎来自模型中某些东西的尺寸设置,但这就是我能猜到的全部。
先感谢您。
解决方案
线
model += [nn.Conv2d(128, 256, 5, stride=2, padding=2, bias=False), norm(128), act] # 8x8
这是一个批量标准化输入错误,应该是 256。
推荐阅读
- java - 是否可以结合杰克逊注解@JacksonXmlText 和@JsonRawValue?
- r - 使用名称提取列表元素
- python - 如何获得 git diff 以便我可以使用它来使用 semver 来提升我的版本?(ThreeDotLabs 教程)
- python-3.x - convert_to_generator_like num_samples 属性错误:“int”对象没有属性“shape”
- python - 为什么 Python http 请求会创建 TIME_WAIT 连接?
- spring - Spring Boot 和 Spring Security 过滤器未过滤正确的请求
- jquery - 如何在主 div 中为多个 div 应用滚动条
- javascript - 斜杠前最后一个破折号后的反应路由器匹配字符串
- sql - 如何构建一个查询,为员工提取员工 ID 列表,这些员工的值发生在员工独有的日期范围内?
- python - 说python中属性的setter函数就像重载赋值运算符是对的吗?