python - PyTorch:定义层不参与前向传播,但影响损失值
问题描述
最近在使用 Pytorch 做逻辑回归的简单实验时,遇到了一个令人困惑的现象。
问题是当我像这样固定随机种子时:
def set_seed(seed, cuda=True):
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
torch.cuda.manual_seed(seed)
并定义了以下具有 2 层的模型:
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.hidden = nn.Linear(784, 100)
self.output = nn.Linear(100, 10)
def forward(self, x):
x = self.hidden(x)
x = self.output(x)
return x
用以下方法训练网络:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
原始损失值为0.6422,可重现。
但是,当我添加一个未参与到转发过程中的附加层时,如下所示:
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.hidden = nn.Linear(784, 100)
self.output = nn.Linear(100, 10)
self.add = nn.Linear(10,10)
def forward(self, x):
x = self.hidden(x)
x = self.output(x)
return x
原来的损失值变为0.7431,不等于之前的损失值,模型性能同时下降。
我真的很想知道这其中的原因。谢谢!
解决方案
如果在计算损失之前有其他随机性来源(消耗 RNG 的东西),这是完全可以预料的。由于您没有提供Minimal, Reproducible Example,我猜您正在使用DataLoader
with shuffle=True
。在这种情况下,即使你不使用self.add
层,当你初始化它时,它也会消耗 RNG;因此导致样品的顺序不同。如果随机性来自带有 的 DataLoader shuffle=True
,您可以通过向 DataLoader 提供不同的 RNG 来控制它。像这样的东西:
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision.transforms import ToTensor
def set_seed(seed, cuda=True):
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
torch.cuda.manual_seed(seed)
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.hidden = nn.Linear(784, 100)
self.output = nn.Linear(100, 10)
# self.add = nn.Linear(10, 10) # try with and without
def forward(self, x):
x = self.hidden(x)
x = self.output(x)
return x
set_seed(0)
m = net()
bs = 4
ds = torchvision.datasets.MNIST(root=".", train=True, transform=ToTensor(), download=True)
rng_dl = torch.Generator()
dl = torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, num_workers=0, generator=rng_dl)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
for x, y in dl:
y_hat = m(x.view(bs, -1))
l = criterion(y_hat, y)
print(l)
exit()
请记住,它可能是其他几件事,例如数据增强和对依赖随机操作的函数的其他调用。如果您可以提供 MRE,我可以尝试给出更具体的答案。
推荐阅读
- timestamp - 在某个时间戳停止 rosbag 以与修改后的包进行比较
- python - 我最近更改了编译器路径以运行 c++ 代码,但现在我无法运行任何 python 代码。我该如何解决?
- javascript - 为什么这是有效的 javascript?那些括号在做什么?(不是匿名函数)
- python - KeyError: 0 ,同时运行 while 循环
- python - 我使用 python 创建了一个排行榜,但我的代码打印在一行中。我希望它像排行榜格式一样单独打印每个元素
- r - 如何将性别列的值更改为 R 中的数值?
- php - Magento 2 -- debug.log -- main.DEBUG:操作“供应商\模块\控制器\索引\发布\拦截器”的请求验证失败
- java - 如何在 Java 中钳位值?
- java - 如何在 Android Studio 的 Java 库模块中使用 Dagger?
- c# - 如何在表面视图中从相机捕获图像