python - RuntimeError: _thnn_mse_loss_forward 没有为类型 torch.cuda.LongTensor 实现
问题描述
我正在使用 PyTorch,但出现错误!我的错误代码如下:
for train_data in trainloader:
example_count += 1
if example_count == 100:
break
optimer.zero_grad()
image, label = train_data
image = image.cuda()
label = label.cuda()
out = model(image)
_, out = torch.max(out, 1)
# print(out.cpu().data.numpy())
# print(label.cpu().data.numpy())
# out = torch.zeros(4, 10).scatter_(1, out.cpu(), 1).cuda()
# label= torch.zeros(4, 10).scatter_(1, label.cpu(), 1).cuda()
l = loss(out, label)
l.bakeward()
optimer.setp()
j += 1
count += label.size(0)
acc += (out == label).sum().item()
if j % 1000 == 0:
print(j + ' step:curent accurity is %f' % (acc / count))
回溯:
Traceback (most recent call last):
File "VGG实现.py", line 178, in <module>
utils.train(testloader,model)
File "VGG实现.py", line 153, in train
l=loss(out,label)
File "/home/tang/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/home/tang/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 435, in forward
return F.mse_loss(input, target, reduction=self.reduction)
File "/home/tang/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 2156, in mse_loss
ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
RuntimeError: _thnn_mse_loss_forward is not implemented for type torch.cuda.LongTensor
我得到一个答案,这里 Pytorch RuntimeError: "host_softmax" not implemented for 'torch.cuda.LongTensor'
但我不知道如何解决这个问题。
解决方案
查看以下文档torch.max()
:
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
返回给定维度 dim 中输入张量的每一行的最大值。第二个返回值是找到的每个最大值的索引位置 (argmax)。
你的代码行
_, out = torch.max(out, 1)
采用模型的浮点预测,out
,并用于torch.max()
返回最大预测的argmax = type long int 索引。
您收到的错误消息是您的loss
函数(我猜您正在使用带有 softmax 的交叉熵)不支持长类型的第一个参数。
此外,您不能通过 argmax 进行导数 - 所以我认为转换out
为 float using.to(torch.float)
也不会对您有任何好处。
您正在使用的损失函数中的 softmax 函数正在为您处理 argmax。
推荐阅读
- forms - Google 应用程序与 Google 搜索引擎的隔离
- deep-learning - 我训练了 pix2code,但无论给出什么图像,它总是输出相同的 DSL 内容
- javascript - 将 JavaScript 对象作为字符串发送到
- python - ValueError:发现样本数量不一致的输入变量:[13, 26]
- image - React Native 应用程序中的所有图像/快速图像不适用于 iOS 14 测试版和 Xcode 12 测试版
- oauth - 如何在 PassportJS 中有变量回调?
- c - 在 Windows 中将数据从 USB/RS485 串行发送到 RS485 C 代码
- java - 缺少工件 org.apache.flink:flink-table:jar:1.10.1
- xamarin - 尝试聚焦时,Webview 文本字段滚动到顶部 - Xamarin.forms iOS
- javascript - 当 firebase 数据库更新时,Chart.js 不会动态更新。(Vue.js)