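python - PyTorch: RuntimeError: expected dtype Float but got dtype Long
Problem description
I ran into this strange error while building a simple NN in PyTorch. I don't understand the error, or why the backward function would care about Long and Float dtypes. Has anyone run into this? Thanks for your help.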
Traceback (most recent call last):
File "test.py", line 30, in <module>
loss.backward()
File "/home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/tensor.py", line 198, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 100, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: expected dtype Float but got dtype Long (validate_dtype at /opt/conda/conda-bld/pytorch_1587428398394/work/aten/src/ATen/native/TensorIterator.cpp:143)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x4e (0x7f5856661b5e in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: at::TensorIterator::compute_types() + 0xce3 (0x7f587e3dc793 in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #2: at::TensorIterator::build() + 0x44 (0x7f587e3df174 in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #3: at::native::smooth_l1_loss_backward_out(at::Tensor&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long) + 0x193 (0x7f587e22cf73 in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0xe080b7 (0x7f58576960b7 in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_cuda.so)
frame #5: at::native::smooth_l1_loss_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, long) + 0x16e (0x7f587e23569e in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0xed98af (0x7f587e71c8af in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xe22286 (0x7f587e665286 in /home/liuyun/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so)
Here is the source code:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import models
from UTKLoss import MultiLoss
from ipdb import set_trace
# out features [13, 2, 5]
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 20)
model_ft.cuda()
criterion = MultiLoss()
optimizer = torch.optim.Adam(model_ft.parameters(), lr = 1e-3)
image = torch.randn((1, 3, 128, 128)).cuda()
age = torch.randint(110, (1,)).cuda()
gender = torch.randint(2, (1,)).cuda()
race = torch.randint(5, (1,)).cuda()
optimizer.zero_grad()
output = model_ft(image)
age_loss, gender_loss, race_loss = criterion(output, age, gender, race)
loss = age_loss + gender_loss + race_loss
loss.backward()
optimizer.step()
Here is the loss function I defined:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, output, age, gender, race):
        age_pred = output[:, :13]
        age_pred = torch.sum(age_pred, 1)
        gender_pred = output[:, 13: 15]
        race_pred = output[:, 15:]
        age_loss = F.smooth_l1_loss(age_pred.view(-1, 1), age.cuda())
        gender_loss = F.cross_entropy(gender_pred, torch.flatten(gender).cuda(), reduction='sum')
        race_loss = F.cross_entropy(race_pred, torch.flatten(race).cuda(), reduction='sum')
        return age_loss, gender_loss, race_loss
Solution
Change the criterion call to:
age_loss, gender_loss, race_loss = criterion(output, age.float(), gender, race)
If you look at your error, you can trace it to:
frame #3: at::native::smooth_l1_loss_backward_out
In the MultiLoss class, smooth_l1_loss is applied to age, so I changed its type to float (the expected dtype is Float) when passing it to the criterion. You can check that age is torch.int64 (i.e. torch.long), which is the default dtype returned by torch.randint, by printing age.dtype.
I no longer get the error after making this change. Hope it helps.
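As an alternative to casting at the call site, the cast can live inside the loss itself. The following is only a minimal sketch under a few assumptions that are not in the original post: it runs on CPU with a random tensor standing in for the ResNet output, it drops the .cuda() calls (the targets are assumed to already be on the same device as the output), and it reshapes the age target to match the prediction's shape rather than relying on broadcasting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLoss(nn.Module):
    """Variant of the question's loss that casts the age target itself."""

    def forward(self, output, age, gender, race):
        # Columns 0-12: age, 13-14: gender, 15-19: race (as in the question).
        age_pred = torch.sum(output[:, :13], 1)  # shape (batch,)
        gender_pred = output[:, 13:15]
        race_pred = output[:, 15:]
        # smooth_l1_loss is a regression loss, so give it a target with the
        # same floating dtype and the same shape as the prediction.
        age_target = age.to(age_pred.dtype).view_as(age_pred)
        age_loss = F.smooth_l1_loss(age_pred, age_target)
        # cross_entropy takes class indices, so these targets stay int64.
        gender_loss = F.cross_entropy(gender_pred, torch.flatten(gender), reduction='sum')
        race_loss = F.cross_entropy(race_pred, torch.flatten(race), reduction='sum')
        return age_loss, gender_loss, race_loss


# Hypothetical stand-in for model_ft(image): one sample, 20 output features.
output = torch.randn(1, 20, requires_grad=True)
age = torch.randint(110, (1,))   # int64 (torch.long) by default
gender = torch.randint(2, (1,))
race = torch.randint(5, (1,))

age_loss, gender_loss, race_loss = MultiLoss()(output, age, gender, race)
loss = age_loss + gender_loss + race_loss
loss.backward()                  # backward now sees only float tensors
print(age.dtype, loss.dtype, output.grad.shape)
```

Only the regression target is cast here. The gender and race targets are left as torch.long because F.cross_entropy expects integer class indices for them.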