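python - Model output from the sigmoid function is almost equal to 0.5 and stays constant
Problem description
I have been training Inception-ResNet V2 on a binary image classification problem, and during training I noticed that the outputs do not converge to 0 or 1. They just hover around 0.5. What seems to be wrong? I am training the pretrained model for 4 epochs on a highly imbalanced dataset, which is why I am also using a weighted random sampler. The batch size is 128, the optimizer is Adam, and the learning rate is 0.001.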
```
# Imports inferred from the names used below; model, criterion, adam,
# train_loader and val_loader are defined elsewhere in the original script.
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score

device = "cuda"
epochs = 4
print("======== Training for ", epochs, "epochs=============")
for epoch in range(epochs):
    total_loss = 0
    model.train()
    print("Training.......")
    print("======== EPOCH #", epoch, "=================")
    tmp_acc = 0
    for i, batch in enumerate(train_loader):
        img, label = batch["images"], batch["labels"]
        label = label.type(torch.FloatTensor)
        img, label = img.to(device), label.to(device)
        model.zero_grad()
        op, aux = model(img)
        label_cpu = label.cpu().numpy()
        op = F.sigmoid(op)
        output = op.detach().cpu().numpy()
        tmp_acc += accuracy_score(output, label_cpu)
        loss = criterion(op, label)
        total_loss = loss.item()
        loss.backward()
        adam.step()
        if (i % 10 == 0 and i > 0):
            print("STEP: ", i, "of steps ", len(train_loader))
            print("Current loss: ", total_loss/i)
            print("Training Accuracy ", tmp_acc/i)
            print("OP", op)
            print("Label", label_cpu)
    avg_loss = total_loss/len(train_loader)
    print("The loss after ", epoch, " epochs is ", avg_loss)
    model.eval()
    print("Validating.....")
    tmp_accuracy = 0
    z_count, o_count = 0, 0
    z_count_truth, o_count_truth = 0, 0
    for i, batch in enumerate(val_loader):
        img, label = batch["images"], batch["labels"]
        img = img.to(device)
        with torch.no_grad():
            op = F.sigmoid(model(img))
        op_cpu = op.detach().cpu().numpy()
        label = label.numpy()
        tmp_accuracy += accuracy_score(op_cpu, label)
        z_count += np.sum(op_cpu == 0)
        o_count += np.sum(op_cpu == 1)
        z_count_truth += np.sum(label == 0)
        o_count_truth += np.sum(label == 1)
    percent_correct_z = z_count/z_count_truth
    percent_correct_o = o_count/o_count_truth
    accuracy = tmp_accuracy/len(val_loader)
    print("Accuracy: ", "is ", accuracy)
    #print("Percent of correct zero labels ",percent_correct_z)
    #print("Percent of correct one labels ",percent_correct_o)
```
The output looks like this:
STEP: 90 of steps 99
Current loss: 0.007694996065563626
Training Accuracy 0.5019965277777778
OP tensor([[0.4962],
[0.4956],
[0.4950],
[0.4957],
[0.4945],
[0.4957],
[0.4952],
[0.4965],
[0.4950],
[0.4962],
[0.4956],
[0.4956],
[0.4951],
[0.4953],
[0.4956],
[0.4958],
[0.4949],
[0.4945],
[0.4955],
[0.4924],
[0.4952],
[0.4952],
[0.4958],
[0.4953],
[0.4959],
[0.4952],
[0.4965],
[0.4956],
[0.4956],
[0.4381],
[0.4951],
[0.4946],
[0.4957],
[0.4951],
[0.4955],
[0.4952],
[0.4955],
[0.4948],
[0.4951],
[0.4960],
[0.4956],
[0.4955],
[0.4958],
[0.4957],
[0.4953],
[0.4954],
[0.4955],
[0.4959],
[0.4949],
[0.4960],
[0.4953],
[0.4949],
[0.4951],
[0.4952],
[0.4949],
[0.4954],
[0.4956],
[0.4951],
[0.4947],
[0.4958],
[0.4953],
[0.4960],
[0.4959],
[0.4958],
[0.4948],
[0.4947],
[0.4957],
[0.4961],
[0.4955],
[0.4959],
[0.4955],
[0.4954],
[0.4959],
[0.4952],
[0.4955],
[0.4951],
[0.4962],
[0.4961],
[0.4961],
[0.4960],
[0.4956],
[0.4959],
[0.4953],
[0.4960],
[0.4955],
[0.4949],
[0.4958],
[0.4953],
[0.4955],
[0.4959],
[0.4951],
[0.4961],
[0.4939],
[0.4954],
[0.4953],
[0.4958],
[0.4953],
[0.4949],
[0.4959],
[0.4958],
[0.4960],
[0.4949],
[0.4957],
[0.4964],
[0.4949],
[0.4956],
[0.4952],
[0.4959],
[0.4954],
[0.4958],
[0.4954],
[0.4951],
[0.4953],
[0.4953],
[0.4958],
[0.4954],
[0.4955],
[0.4954],
[0.4960],
[0.4946],
[0.4950],
[0.4953],
[0.4957],
[0.4956],
[0.4954],
[0.4940],
[0.4951],
[0.4955]], device='cuda:0', grad_fn=<SigmoidBackward>)
Label [0. 1. 0. 1. 1. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 1. 0.
1. 0. 0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1. 1. 0. 1. 1.
0. 0. 1. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 1. 1.
0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 0. 1. 0. 0. 1. 1. 0.
0. 1. 1. 0. 0. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.
1. 0. 1. 1. 0. 1. 1. 0.]
Here the OP tensor holds the model outputs after the sigmoid (what I call the logits), and Label holds the original labels.
Solution
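Since you have
op = F.sigmoid(op)
and all of your output values are around 0.5, it follows from the definition of the sigmoid function that all of its input values are very close to 0.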
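To make that argument concrete, here is a minimal sketch that inverts the sigmoid for a few of the values printed in the question. The helper names `sigmoid` and `logit` are introduced here for illustration only and are not part of the original script.

```python
import math


def sigmoid(x: float) -> float:
    """Standard logistic function."""
    return 1.0 / (1.0 + math.exp(-x))


def logit(p: float) -> float:
    """Inverse of the sigmoid: the pre-activation value that yields p."""
    return math.log(p / (1.0 - p))


# A few of the values printed in the question's OP tensor.
printed_outputs = [0.4962, 0.4956, 0.4924, 0.4381]

for p in printed_outputs:
    print(f"sigmoid output {p:.4f} -> pre-sigmoid value {logit(p):+.4f}")
```

Outputs near 0.496 correspond to pre-sigmoid values of roughly -0.015, so the network is emitting almost the same near-zero number for every image in the batch.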
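Pre-sigmoid values that close to zero most likely mean that either all of your input images are zero or the weights of the network were not initialized properly. Since ResNet v2 has a lot of skip connections, the problem seems more likely to come from your input images.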
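One quick way to test the input hypothesis is to print basic statistics for a batch before it reaches the model. The sketch below is only an illustration: `describe_batch` is a hypothetical helper, and it assumes the batch has already been converted to a NumPy array (for example with `img.detach().cpu().numpy()` inside the training loop).

```python
import numpy as np


def describe_batch(images) -> dict:
    """Return simple statistics that reveal blank or constant inputs."""
    arr = np.asarray(images, dtype=np.float64)
    return {
        "shape": arr.shape,
        "min": float(arr.min()),
        "max": float(arr.max()),
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "all_zero": bool(np.all(arr == 0)),
    }


# Synthetic stand-ins for real batches (batch, channels, height, width).
blank_batch = np.zeros((4, 3, 8, 8))
random_batch = np.random.default_rng(0).random((4, 3, 8, 8))

print(describe_batch(blank_batch))
print(describe_batch(random_batch))
```

If `all_zero` is true, or `std` is essentially zero, for real training batches, the data loading or preprocessing step is the place to look rather than the model.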
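As a general first test, it is often useful to check whether your network can overfit a very small dataset. In your case, I would first try to overfit a single image to its label, which makes debugging much easier than it is with your current batch size and shuffling.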
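A single-example overfitting check could be sketched as follows. Everything here is an assumption made for illustration: the tiny convolutional model stands in for Inception-ResNet V2, the image is random noise, and `nn.BCEWithLogitsLoss` is applied to the raw model output instead of calling a sigmoid first, which is a substitution and not what the question's code does.

```python
import torch
from torch import nn


def overfit_single_example(model, image, label, steps=200, lr=1e-2):
    """Train on one (image, label) pair and return the loss at each step."""
    criterion = nn.BCEWithLogitsLoss()  # expects raw scores, applies the sigmoid itself
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        raw_score = model(image)
        loss = criterion(raw_score, label)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)

# Hypothetical stand-in model; swap in the real network to run the actual check.
tiny_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 1),
)

image = torch.rand(1, 3, 32, 32)  # one fake RGB image
label = torch.ones(1, 1)          # its target class

losses = overfit_single_example(tiny_model, image, label)
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

If the loss on a single real image does not fall toward zero in a setup like this, the problem lies in the model, loss, or optimizer wiring and not in the dataset imbalance or the sampler.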