python - RuntimeError:张量大小不匹配(size mismatch)
问题描述
错误消息:RuntimeError: size mismatch, m1: [64 x 3200], m2: [512 x 1] at C:/w/1/s/windows/pytorch/aten/src\THC/generic/THCTensorMathBlas.cu:290
代码如下:
class Generator(nn.Module):
    """Conditional generator: turns (noise, class label) pairs into images.

    Configuration is read from the module-level ``opt`` object
    (``n_classes``, ``latent_dim``, ``img_size``, ``channels``).
    """

    def __init__(self):
        super().__init__()

        # One learned vector of length latent_dim per class label.
        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)

        # Spatial size before the two 2x upsampling stages below. The 3x3,
        # stride-1, padding-1 convolutions keep the size, so the generated
        # image is 4 * (opt.img_size // 4) pixels per side.
        self.init_size = opt.img_size // 4
        projected_features = 128 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, projected_features))

        # NOTE(review): the second positional argument of BatchNorm2d is
        # `eps`, so 0.8 is used as eps here (not as momentum) -- confirm
        # that this is intended.
        layers = [
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate a batch of images.

        The label embedding is multiplied element-wise with ``noise``,
        projected by ``l1``, reshaped to 128 feature maps of
        ``init_size`` x ``init_size`` and passed through ``conv_blocks``.
        """
        label_vectors = self.label_emb(labels)
        gen_input = torch.mul(label_vectors, noise)
        projected = self.l1(gen_input)
        batch_size = projected.shape[0]
        feature_maps = projected.view(batch_size, 128, self.init_size, self.init_size)
        return self.conv_blocks(feature_maps)
class Discriminator(nn.Module):
    """Discriminator with an adversarial head and an auxiliary class head.

    Configuration is read from the module-level ``opt`` object
    (``channels``, ``img_size``, ``n_classes``). Input images are expected
    to be ``opt.img_size`` pixels per side; otherwise the flattened feature
    size will not match the linear layers.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if bn:
                # NOTE(review): the second positional argument of
                # BatchNorm2d is `eps`, so 0.8 acts as eps, not momentum
                # -- confirm that this is intended.
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            return block

        # (in_channels, out_channels, use_batchnorm) for each block.
        block_specs = [
            (opt.channels, 16, False),
            (16, 32, True),
            (32, 64, True),
            (64, 128, True),
        ]
        conv_layers = []
        for in_filters, out_filters, use_bn in block_specs:
            conv_layers.extend(discriminator_block(in_filters, out_filters, bn=use_bn))
        self.conv_blocks = nn.Sequential(*conv_layers)

        # The height and width of downsampled image.
        # Each block's Conv2d (kernel 3, stride 2, padding 1) maps a side
        # length n to ceil(n / 2). `opt.img_size // 2 ** 4` equals that only
        # when img_size is a multiple of 16 (e.g. 70 -> 35 -> 18 -> 9 -> 5,
        # but 70 // 16 == 4), so apply the ceil-halving once per block.
        ds_size = opt.img_size
        for _ in block_specs:
            ds_size = (ds_size + 1) // 2

        # Output layers
        flat_features = 128 * ds_size ** 2
        self.adv_layer = nn.Sequential(nn.Linear(flat_features, 1), nn.Sigmoid())
        # dim=1 is what the implicit choice resolves to for the 2-D input
        # produced in forward(); stating it avoids the deprecation warning.
        self.aux_layer = nn.Sequential(nn.Linear(flat_features, opt.n_classes), nn.Softmax(dim=1))

    def forward(self, img):
        """Return ``(validity, label)`` for a batch of images.

        ``validity`` is the sigmoid output of the adversarial head and
        ``label`` the softmax output of the auxiliary class head; both are
        computed from the flattened convolutional features.
        """
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)  # flatten to (batch, features)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        return validity, label
# Loss functions: binary cross-entropy for the real/fake output and
# cross-entropy for the class output.
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

os.makedirs("../../data/mnist", exist_ok=True)

labels_path = 'C:/project/PyTorch-GAN/ulna/train-labels-idx1-ubyte.gz'
images_path = 'C:/project/PyTorch-GAN/ulna/train-images-idx3-ubyte.gz'
label_name = []

# Read the gzip-compressed label and image files, skipping the first 8 and
# 16 bytes respectively.
# NOTE(review): the "-ubyte" file names suggest unsigned 8-bit data while
# the buffers are decoded as int32 -- confirm the dtype.
with gzip.open(labels_path, 'rb') as lbpath:
    labels = np.frombuffer(lbpath.read(), dtype="int32", offset=8)

with gzip.open(images_path, 'rb') as imgpath:
    images = np.frombuffer(
        imgpath.read(), dtype="int32", offset=16
    ).reshape(len(labels), 70, 70, 1)

# Preprocessing for the image-folder dataset: 70x70, single channel,
# tensor conversion, then normalization with mean 0.5 and std 0.5.
# NOTE(review): the 70x70 size is hard-coded here, while Generator and
# Discriminator size their layers from opt.img_size -- the two must agree.
hand_transform2 = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

dataset1 = datasets.ImageFolder(
    'C:/project/PyTorch-GAN/ulna/ulna',
    transform=hand_transform2,
)
dataloader = torch.utils.data.DataLoader(
    dataset1,
    batch_size=opt.batch_size,
    shuffle=True,
)
Traceback 如下:
Traceback (most recent call last):
File "acgan.py", line 225, in <module>
real_pred, real_aux = discriminator(real_imgs)
File "C:\Users\S\AppData\Local\conda\conda\envs\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "acgan.py", line 110, in forward
validity = self.adv_layer(out)
File "C:\Users\S\AppData\Local\conda\conda\envs\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "C:\Users\S\AppData\Local\conda\conda\envs\venv\lib\site-packages\torch\nn\modules\container.py", line 92, in forward
input = module(input)
File "C:\Users\S\AppData\Local\conda\conda\envs\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "C:\Users\S\AppData\Local\conda\conda\envs\venv\lib\site-packages\torch\nn\modules\linear.py", line 87, in forward
return F.linear(input, self.weight, self.bias)
File "C:\Users\S\AppData\Local\conda\conda\envs\venv\lib\site-packages\torch\nn\functional.py", line 1370, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [64 x 3200], m2: [512 x 1] at C:/w/1/s/windows/pytorch/aten/src\THC/generic/THCTensorMathBlas.cu:290
我正在练习 GAN 代码。修改前的完整 GAN 代码可以在以下链接找到:https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/acgan/acgan.py 。输入图像是调整为 70x70 的 X 射线图像,输出图像是通过学习输入的 X 射线图像而新生成的假 X 射线图像。该代码在使用 MNIST 数据集时运行良好。恐怕我完全不清楚代码的问题出在哪里。请帮帮我!谢谢。
解决方案
看起来 `opt.img_size` 可能仍然设置为 32,就像您在使用 CIFAR 时一样。既然您把图像调整为 70,它就应该设置为 70。
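无论如何,还会出现另一个问题,因为 `ds_size = opt.img_size // 2 ** 4` 对 `opt.img_size=70` 并不成立:4 个步长为 2 的卷积会把 70 依次缩小为 35、18、9、5(所以展平后是 128 × 5 × 5 = 3200,正是错误消息中的 m1),而 70 // 16 = 4。如果您想要硬编码的解决方案,请设置 `ds_size=5`。这修复了鉴别器,但同样的问题也会发生在生成器上。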
如果您不了解如何正确解决此问题,我建议您花一些时间阅读这些模型的工作原理。如果您想按原样使用代码,我建议您把 `img_size` 设为 16 的倍数,例如 `opt.img_size=80`,这样就不会有任何问题。为避免其他问题,您可能希望使用 `transforms.Resize((opt.img_size, opt.img_size))`,而不是在那里硬编码 `img_size`。
推荐阅读
- reason - 如何使用 bs-deriving 覆盖/提供自定义实例
- r - R中的感知器不收敛
- sql - 关键字“where”附近的语法不正确。无法理解这个错误
- python - 如何将 MySQL 结果检索到可以按索引搜索的字典中?(在 Python 中)
- python-3.x - np.where 的语法替换列值
- c++ - 使用 GetDC 方法而不是其他方法是否有性能改进?
- javascript - Javascript - 向键添加多个值
- visual-studio - Visual Studio 中的“符号”一词是什么意思?
- html - 单击并拖动标签时,自定义复选框不会切换 [已解决]
- html - CSS动画在径向渐变背景下不平滑