pytorch - 将 torch.nn.CrossEntropyLoss 与 label as int 和 prob 张量进行比较会产生错误
问题描述
我有这个标准:
criterion = nn.CrossEntropyLoss()
在训练阶段我有:
label = tensor([0.])
outputs = tensor([[0.0035, 0.0468]], grad_fn=<AddmmBackward>)
当我尝试比较它们时:
criterion(outputs, label)
我收到此错误:
ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward
解决方案
nn.CrossEntropyLoss
期望它的label
输入是 typetorch.Long
而不是torch.Float
.
请注意,此行为与预期与输入具有相同类型的nn.BCELoss
位置相反。target
如果您只是.
从标签中删除 :
label = torch.tensor([0]) # no . after 0 - now it is an integer
pytorch 会自动创建label
为torch.Long
类型,你没问题:
In [*]: criterion(outputs, torch.tensor([0]))
Out[*]: tensor(0.7150)
评论planet_pluto和Craig.Li的其他答案:
- 投射现有张量的更一般的方法是使用
.to(...)
:
label = torch.tensor([0]).to(dtype=torch.long)
但是,create-and-casting 并不是一种非常有效的处理方式:想想看,你让 pytroch 创建一个torch.float
张量,然后将其转换为torch.long
.
dtype
或者,您可以在创建张量时明确定义所需的:
label = torch.tensor([0.], dtype=torch.long)
这种方式 pytorch 创建label
所需的dtype
并且不需要第二阶段的铸造。
推荐阅读
- vba - 使用VBA从搜索网页的结果中提取表格
- javascript - 如何将“ref”添加到嵌入式嵌套反应组件
- reactjs - CoreUI-free-react-admin-template 不工作
- python - Google Foobar escape-pods 测试用例 N. 4 失败
- javascript - 导航中的间距
- reactjs - 如何只重新渲染一个单元格?
- html - Github页面上的损坏图像
- javascript - 如何在javascript / jquery中获取div的指定高度
- wordpress - 如何使用 wordpress 制作像 cambly 一样的网站
- python - 如何将模型字段添加到 django 模型