python - pytorch nn.linear 中的 forward 给出 NaN
问题描述
我正在研究 Pytorch 模型。训练集如下所示:
Pclass Sex Age SibSp Parch Fare Embarked
0 3 0 1 1 0 0 1.0
1 1 1 1 1 0 3 2.0
2 3 1 1 0 0 0 1.0
3 1 1 1 1 0 3 1.0
4 3 0 1 0 0 0 1.0
... ... ... ... ... ... ... ...
886 2 0 1 0 0 0 1.0
887 1 1 0 0 0 1 1.0
888 3 1 0 1 2 0 1.0
889 1 0 1 0 0 1 2.0
890 3 0 1 0 0 0 0.0
891 rows × 7 columns
类似的是测试集。training_label 是一种热编码。
型号为:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(7, 10)
self.bc1 = nn.BatchNorm1d(10)
self.fc2 = nn.Linear(10, 10)
self.bc2 = nn.BatchNorm1d(10)
self.fc3 = nn.Linear(10, 1)
self.bc3 = nn.BatchNorm1d(1)
def forward(self, x):
x = self.fc1(x)
x = self.bc1(x)
x = F.relu(x)
x = self.fc2(x)
x = self.bc2(x)
x = F.relu(x)
x = self.fc3(x)
x = self.bc3(x)
#x = F.sigmoid(x)
return x
model = Net()
我正在调用 TensorDataset 和 DataLoader 来加载批量大小为 100 的 (train_data, train_label)。
我正在尝试调用模型:
num_epochs = 1
# Repeat for given number of epochs
for epoch in range(num_epochs):
# Train with batches of data
for xb,yb in train_dl:
pred = model(xb)
yb = torch.unsqueeze(yb, 1)
print(pred, yb)
#print('grad', model.fc1.weight.grad)
l = loss(pred, yb)
print('loss',l)
# 3. Compute gradients
l.backward()
# 4. Update parameters using gradients
optimizer.step()
# 5. Reset the gradients to zero
optimizer.zero_grad()
一段时间后,我得到 NaN 作为 pred = model(xb) 的输出。正如你所看到的,我只运行了 1 个 epoch,所以我在第一个 epoch 中获得了一些批次的 NaN。
我不确定为什么会这样。我检查并找到了一些解决方案,比如减少纪元,但我的纪元已经很低了。我还对数据集进行了预处理,以便所有输入都是 0、1、2、3 或 4(就像我预处理了年龄、性别等)。
我没有收到任何错误消息,但我可以看到输出为 NaN,当然,当我打印它们时,损失也是如此。我也尝试了不同的优化,损失。现在我正在使用:
#model.cuda()
#loss = nn.BCEWithLogitsLoss()
loss = F.mse_loss
learning_rate = 0.0001
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
我也试过 BCEWithLogitsLoss。nn.BCELoss 给出了输入应该在 0 和 1 之间的错误(这里有些输入变成了 Nan)
任何人都可以帮助解释发生了什么吗?
解决方案
推荐阅读
- python - 如何使用 Python 将列表转换为多列?
- visual-studio - 代码 EXXXX 的 Visual C++ 错误是什么?
- objective-c - AVCaptureDeviceInput 在第一秒运行 AVCaptureSession 后丢弃帧,使用 nativescript
- algorithm - 两个节点之间的非循环路径的联合
- python-3.x - 嵌套if进入python中的嵌套for
- visual-studio - 新的 .NET Core 项目 (2.2) 无法编译
- go - 尝试减小 Go 程序的可执行文件大小
- php - 致命错误:未捕获的错误:在 bool 上调用成员函数 fetch()
- php - “symfony serve”在 php 7.2 (docker/alpine) 中崩溃
- android - Android Studio:从应用程序打开谷歌地图时应用程序停止工作