deep-learning - 为什么我的损失在训练 10 个 epoch 时没有减少?
问题描述
我的硬件是带有 nvidia rtx 3060 gpu 的 Ryzen 5000 系列 cpu。我目前正在从事一项学校作业,涉及使用深度学习模型(在 PyTorch 中实现)从 CT 切片图像中预测 COVID 诊断。数据集可以在 GitHub 上的这个 url 上找到:https ://github.com/UCSD-AI4H/COVID-CT
我编写了一个自定义数据集,该数据集从数据集中获取图像并将其大小调整为 224x224。我还使用 skimage.color 将所有 rgba 或灰度图像转换为 rgb。其他变换包括随机水平和垂直翻转,以及 ToTensor()。为了评估模型,我使用了 sklearn.metrics 来计算模型的 AUC、F1 分数和准确度。
我的麻烦是我无法让模型训练。在 10 个 epoch 之后,损失并没有减少。我尝试调整优化器的学习率,但没有帮助。任何建议/想法将不胜感激。谢谢!
class RONANet(nn.Module):
def __init__(self, classifier_type=None):
super(RONANet, self).__init__()
self.classifier_type = classifier_type
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.classifier = self.compose_classifier()
self.conv_layers = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(32),
self.relu,
self.maxpool,
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
self.relu,
self.maxpool,
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
self.relu,
self.maxpool,
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
self.relu,
self.maxpool,
nn.AdaptiveAvgPool2d(output_size=(1,1)),
)
def compose_classifier(self):
if 'fc' in self.classifier_type:
classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(14**2*256, 256),
self.relu,
nn.Linear(256, 128),
self.relu,
nn.Linear(128, 2))
elif 'conv'in self.classifier_type:
classifier = nn.Sequential(
nn.Conv2d(256, 1, kernel_size=1, stride=1))
return classifier
def forward(self, x):
features = self.conv_layers(x)
out = self.classifier(features)
if 'conv' in self.classifier_type:
out = out.reshape([-1,])
return out
RONANetv1 = RONANet(classifier_type='conv')
RONANetv1 = RONANetv1.cuda()
RONANetv2 = RONANet(classifier_type='fc')
RONANetv2 = RONANetv2.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(RONANetv1.parameters(), lr=0.1)
num_epochs = 100
best_auc = 0.5 # set threshold to random model performance
scores = {}
for epoch in range(num_epochs):
RONANetv1.train()
print(f'Current Epoch: {epoch+1}')
epoch_loss = 0
for images, labels in train_dataloader:
batch_loss = 0
optimizer.zero_grad()
with torch.set_grad_enabled(True):
images = images.cuda()
labels = labels.cuda()
out = RONANetv1(images)
loss = criterion(out, labels)
batch_loss += loss.item()
loss.backward()
optimizer.step()
epoch_loss += batch_loss
print(f'Loss this epoch: {epoch_loss}\n')
current_val_auc, current_val_f1, current_val_acc = get_scores(RONANetv1, val_dataloader)
if current_val_auc > best_auc:
best_auc = current_val_auc
torch.save(RONANetv1.state_dict(), 'RONANetv1.pth')
scores['AUC'] = current_val_auc
scores['f1'] = current_val_f1
scores['Accuracy'] = current_val_acc
print(scores)
.
Output:
Current Epoch: 1
Loss this epoch: 38.038745045661926
{'AUC': 0.6632183908045978, 'f1': 0.0, 'Accuracy': 0.4915254237288136}
Current Epoch: 2
Loss this epoch: 37.96312761306763
Current Epoch: 3
Loss this epoch: 37.93656861782074
Current Epoch: 4
Loss this epoch: 38.045261442661285
Current Epoch: 5
Loss this epoch: 38.01626980304718
Current Epoch: 6
Loss this epoch: 37.93017905950546
Current Epoch: 7
Loss this epoch: 37.913547694683075
Current Epoch: 8
Loss this epoch: 38.049841582775116
Current Epoch: 9
Loss this epoch: 37.95650988817215
解决方案
所以问题是你只训练分类器的第一部分而不是第二部分
# this
optimizer = torch.optim.Adam(RONANetv1.parameters(), lr=0.1)
# needs to become this
from itertools import chain
optimizer = torch.optim.Adam(chain(RONANetv1.parameters(), RONANetv2.parameters()))
你也需要在训练中加入其他 cnn
intermediate_out = RONANetv1(images)
out = RONANetv2(intermediate_out)
loss = criterion(out, labels)
batch_loss += loss.item()
loss.backward()
optimizer.step()
希望这有助于祝你好运!
推荐阅读
- java - 使用作用域 Storage MediaStore 在应用程序中共享意图问题
- postgresql - POSTGRESQL:加载 json 文件时出错“SQL 错误 [22P04]:ERREUR:données supplémentaires après la dernière Colonne 出席”
- html - IDK 为什么我的 mapbox 地图不是交互式的。县城应该可以点击
- django - 我的图像没有被可视化 Django
- c# - unity 3D 空中运动和跳跃问题
- amazon-athena - Amazon Athena 的“In Clause”语句中允许有多少个值
- c# - 有谁知道 AsyncPageable 是什么?
- c# - 向 contextmenustrip Windows 窗体中的项目添加功能
- java - Jackson 不会将新的 JSON 对象附加到现有的 Json 文件中
- python - Conda 包,纯粹但有入口点