pytorch - 函数 AddmmBackward 返回了一个无效的梯度
问题描述
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
class NeuralNetwork(nn.Module):
    """Small LeNet-style CNN: two conv/pool stages followed by three linear layers.

    Takes single-channel image batches and returns 3 raw class scores
    (logits) per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # fc1 expects 16 * 5 * 5 = 400 flattened features, i.e. a 5x5 map per
        # channel after both unpadded 5x5 conv + 2x2 pool stages. That is what
        # a 32x32 input produces (32 -> 28 -> 14 -> 10 -> 5); any other input
        # size needs a different in_features here.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)

    def forward(self, x):
        """Run the forward pass and return the logits produced by ``fc3``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Model, loss and optimizer shared by the training script below.
net = NeuralNetwork()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=net.parameters(),
    lr=0.001,
    momentum=0.9,
)
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
def UploadData(path, train):
    """Build shuffled DataLoaders for the train, validation and test image folders.

    Args:
        path: Dataset root directory containing ``train``, ``validation`` and
            ``test`` sub-directories, each read with ``datasets.ImageFolder``.
        train: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(trainloader, validloader, testloader)``; every loader uses a
        batch size of 32 with shuffling enabled.
    """
    import os

    def _make_pipeline(augment):
        """Return the grayscale/resize/crop pipeline, optionally with augmentation."""
        steps = [
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(255),
            transforms.CenterCrop(224),
        ]
        if augment:
            steps += [
                transforms.RandomRotation(30),
                transforms.RandomHorizontalFlip(),
            ]
        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

    # Set up transforms for the train, validation and test datasets.
    train_transforms = _make_pipeline(augment=True)
    # NOTE(review): validation also gets random augmentation, as in the
    # original code; confirm this is intended, since it makes the validation
    # loss vary from run to run.
    valid_transforms = _make_pipeline(augment=True)
    test_transforms = _make_pipeline(augment=False)

    # Set up datasets from image folders.
    train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transform=train_transforms)
    valid_dataset = datasets.ImageFolder(os.path.join(path, 'validation'), transform=valid_transforms)
    test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transform=test_transforms)

    # Set up dataloaders with a batch size of 32.
    trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validloader = DataLoader(valid_dataset, batch_size=32, shuffle=True)
    testloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
    return trainloader, validloader, testloader
trainloader, validloader, testloader = UploadData("/home/lns/research/dataset", True)

epochs = 5
min_valid_loss = np.inf
# Batches must live on the same device as the model's parameters; moving only
# the data to the GPU while the model stays on the CPU raises a device
# mismatch error, so follow the model instead of torch.cuda.is_available().
device = next(net.parameters()).device

for e in range(epochs):
    # ---- Training pass ----
    net.train()  # undo the eval() call made during the previous epoch's validation
    train_loss = 0.0
    for data, labels in trainloader:
        if device.type == "cuda":
            print("using GPU for data")
        data, labels = data.to(device), labels.to(device)

        # Clear the gradients
        optimizer.zero_grad()
        # Forward pass
        target = net(data)
        # Find the loss
        loss = criterion(target, labels)
        # Calculate gradients
        loss.backward()
        # Update weights
        optimizer.step()
        # Accumulate the batch loss
        train_loss += loss.item()

    # ---- Validation pass ----
    valid_loss = 0.0
    net.eval()  # the model object is `net`; `model` was never defined
    with torch.no_grad():  # no gradients are needed to measure validation loss
        for data, labels in validloader:
            if device.type == "cuda":
                print("using GPU for data")
            data, labels = data.to(device), labels.to(device)

            # Forward pass
            target = net(data)
            # Find the loss
            loss = criterion(target, labels)
            # Accumulate the batch loss
            valid_loss += loss.item()

    print('Epoch ',e+1, '\t\t Training Loss: ',train_loss / len(trainloader),' \t\t Validation Loss: ',valid_loss / len(validloader))

    # Checkpoint whenever the summed validation loss improves.
    if min_valid_loss > valid_loss:
        print("Validation Loss Decreased(",min_valid_loss,"--->",valid_loss,") \t Saving The Model")
        min_valid_loss = valid_loss
        # Saving state dict
        torch.save(net.state_dict(), '/home/lns/research/MODEL.pth')
经过大量搜索后,我来这里寻求帮助。有人能帮我理解为什么在反向传播时会出现这个错误吗?我参考了 PyTorch 的 CNN 教程和 GeeksforGeeks 的教程。数据集是 X 射线图像,已转换为灰度并调整为 255 大小。是我的神经网络写错了,还是数据处理不正确?
解决方案
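这是 CNN 的输出与第一个全连接层的输入特征数量之间的大小不匹配。由于卷积层没有填充,对于 28×28 的输入,展平后的元素数量是 16*4*4(即 256),而不是 16*5*5(如果输入是问题中经 CenterCrop(224) 处理后的 224×224 图像,展平后的元素数量则是 16*53*53,即 44944,应相应调整 in_features):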
self.fc1 = nn.Linear(256, 120)
修改后,模型将正确运行:
>>> model = NeuralNetwork()
>>> model(torch.rand(1, 1, 28, 28)).shape
torch.Size([1, 3])
或者,您可以使用 nn.LazyLinear,它会在第一次推理时根据输入形状自动推断 in_features 参数:
self.fc1 = nn.LazyLinear(120)
推荐阅读
- image - SwiftUI URLImage 水平滚动视图
- react-native - React Native - Nodejs Express 社交登录
- php - WordPress / PHP'尝试从远程网址上传图像时尝试获取非对象的属性'提要'
- angular - 如何使用 Angular 中的 firestore 从一个引用中获取一份文档的值?
- elasticsearch - 如何使用 Elasticsearch 对文本输入执行部分单词搜索?
- sql-server - TRY_PARSE 表示服务器之间的科学记数法不同
- json - 如何使用 json 文件与 testcafe 中的多个 webelments 交互
- python - 最大错误答案后如何停止循环?
- sql-server - Datediff 与连接返回非预期结果
- javascript - 使用加载器 ts-node/esm.js 运行节点需要导入具有 .js 扩展名