computer-vision - 从头开始训练网络的准确性极差
问题描述
我正在尝试使用类似于 ImageNet 的数据集从头开始重新训练 resnet50。我写了以下训练循环:
def train_network(epochs , train_loader , val_loader , optimizer , network):
since = time.time ( )
train_acc_history = []
val_acc_history = []
best_model_weights = copy.deepcopy (network.state_dict ( ))
best_accuracy = 0.0
for epoch in range (epochs):
correct_train = 0
correct_val = 0
for x , t in train_loader:
x = x.to (device)
t = t.to (device)
optimizer.zero_grad ( )
z = network (x)
J = loss (z , t)
J.backward ( )
optimizer.step ( )
_ , y = torch.max (z , 1)
correct_train += torch.sum (y == t.data)
with torch.no_grad ( ):
network.eval ( )
for x_val , t_val in val_loader:
x_val = x_val.to (device)
t_val = t_val.to (device)
z_val = network (x_val)
_ , y_val = torch.max (z_val , 1)
correct_val += torch.sum (y_val == t_val.data)
network.train ( )
train_accuracy = correct_train.float ( ) / len (train_loader.dataset)
val_accuracy = correct_val.float ( ) / len (val_loader.dataset)
print (
F"Epoch: {epoch + 1} train_accuracy: {(train_accuracy.item ( ) * 100):.3f}% val_accuracy: {(val_accuracy.item ( ) * 100):.3f}%" ,
flush = True)
# time_elapsed_epoch = time.time() - since
# print ('Time taken for Epoch {} is {:.0f}m {:.0f}s'.format (epoch + 1, time_elapsed_epoch // 60 , time_elapsed_epoch % 60))
if val_accuracy > best_accuracy:
best_accuracy = val_accuracy
best_model_weights = copy.deepcopy (network.state_dict ( ))
train_acc_history.append (train_accuracy)
val_acc_history.append (val_accuracy)
print ( )
time_elapsed = time.time ( ) - since
print ('Training complete in {:.0f}m {:.0f}s'.format (time_elapsed // 60 , time_elapsed % 60))
print ('Best Validation Accuracy: {:3f}'.format (best_accuracy * 100))
network.load_state_dict (best_model_weights)
return network , train_acc_history , val_acc_history
但是我的训练和验证准确性极差,如下所示:
> Epoch: 1 train_accuracy: 3.573% val_accuracy: 3.481%
> Epoch: 2 train_accuracy: 3.414% val_accuracy: 3.273%
> Epoch: 3 train_accuracy: 3.515% val_accuracy: 4.039%
> Epoch: 4 train_accuracy: 3.567% val_accuracy: 4.195%
谷歌搜索后,我发现从头开始训练的准确度通常不会那么差(实际上它们从大约 40% - 50% 开始)。我发现很难理解故障可能在哪里。如果有人可以帮助我找出我可能出错的地方,那就太好了。
谢谢
解决方案
我在没有权重检查点的情况下尝试了您的训练循环,并使用我自己的 ResNet 在时尚MNIST 数据集上获得了超过 90% 的准确度。因此,如果您使用的是良好的损失/优化器,我建议您查看网络架构或数据加载器的创建。
def train_network(epochs , train_loader , val_loader , optimizer , network):
#since = time.time ( )
train_acc_history = []
val_acc_history = []
loss = nn.CrossEntropyLoss()
#best_model_weights = copy.deepcopy (network.state_dict ( ))
#best_accuracy = 0.0
for epoch in range (epochs):
correct_train = 0
correct_val = 0
network.train ( )
for x , t in train_loader:
x = x.to (device)
t = t.to (device)
optimizer.zero_grad ( )
z = network (x)
J = loss (z , t)
J.backward ( )
optimizer.step ( )
_ , y = torch.max (z , 1)
correct_train += torch.sum (y == t.data)
with torch.no_grad ( ):
network.eval ( )
for x_val , t_val in val_loader:
x_val = x_val.to (device)
t_val = t_val.to (device)
z_val = network (x_val)
_ , y_val = torch.max (z_val , 1)
correct_val += torch.sum (y_val == t_val.data)
network.train ( )
train_accuracy = correct_train.float ( ) / len (train_loader.dataset)
val_accuracy = correct_val.float ( ) / len (val_loader.dataset)
print (
F"Epoch: {epoch + 1} train_accuracy: {(train_accuracy.item ( ) * 100):.3f}% val_accuracy: {(val_accuracy.item ( ) * 100):.3f}%" ,
flush = True)
'''
if val_accuracy > best_accuracy:
best_accuracy = val_accuracy
best_model_weights = copy.deepcopy (network.state_dict ( ))
train_acc_history.append (train_accuracy)
val_acc_history.append (val_accuracy)
#time_elapsed = time.time ( ) - since
#print ('Training complete in {:.0f}m {:.0f}s'.format (time_elapsed // 60 , time_elapsed % 60))
print ('Best Validation Accuracy: {:3f}'.format (best_accuracy * 100))
#network.load_state_dict (best_model_weights)
'''
return network , train_acc_history , val_acc_history
optimizer = optim.Adam(net.parameters(), lr = 0.01)
train_network(10,trainloader, testloader, optimizer, net)
Epoch: 1 train_accuracy: 83.703% val_accuracy: 86.820%
Epoch: 2 train_accuracy: 88.893% val_accuracy: 89.400%
Epoch: 3 train_accuracy: 90.297% val_accuracy: 89.700%
Epoch: 4 train_accuracy: 91.272% val_accuracy: 90.640%
Epoch: 5 train_accuracy: 91.948% val_accuracy: 91.250%
...
因此,如果您使用我使用的训练循环进行测试(您使用的是小型模组),但它仍然无法正常工作,我会检查数据加载器并使用网络架构。
推荐阅读
- r - R 不断为我的随机森林模型崩溃,因为我想种植 10,000 棵树并且 do.omit 为 1,000
- tinymce-4 - TinyMCE 4.9.6 TinyMce.min.js 文件获取错误消息
- gremlin - Gremlin Python:在一个查询中计算顶点及其子项
- react-native - 在 react-native 0.60.5 中,无法从内部渲染方法调用类函数
- python - 如何将颜色从 HSV 更改回 RGB
- c# - 绑定到 System.Windows.Forms.TreeNode 集合的 WPF TreeView 不显示任何内容
- python - 如何找到每个数字中有多少个数字Python
- dialogflow-es - 如何解决与 Play Store 相关的问题
- dataframe - 在 spark sql 中应用 isNll 后,实际值也被替换
- python-3.x - rgb的整数列表到图像python