python - PyTorch "where" 有条件的 -- RuntimeError: 预期的标量类型 long long 但发现 float
问题描述
我有 2 个带渐变的张量:
a = tensor([[0.0000, 0.0000, 0.2716, 0.0000, 0.4049, 0.0000, 0.2126, 0.8649, 0.0000,
0.0000]], grad_fn=<ReluBackward0>)
b = tensor([[0.5842, 0.4618, 0.4047, 0.5714, 0.4841, 0.5683, 0.4030, 0.3779, 0.4436,
0.4365]], grad_fn=<SigmoidBackward>)
我正在尝试使用第二个张量 ( b
) 作为阈值,同时保持张量的可微性:
torch.where(a < b, 0, a)
但是,我收到一个错误
RuntimeError: expected scalar type long long but found float
我可以将张量转换为long
with
a = torch.tensor([0.0000, 0.0736, 0.5220, 0.0000, 0.0000, 0.1783, 0.0000, 0.0000, 0.0000,
0.0000]).type(torch.LongTensor)
b = torch.tensor([0.4596, 0.4635, 0.5073, 0.4358, 0.5551, 0.5089, 0.5348, 0.5573, 0.5656,
0.5886]).type(torch.LongTensor)
然后条件操作可以正常工作:
torch.where(a < b, 0, a)
虽然 1. 它给了我错误的答案(它只是将每个张量转换为零):
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
2.它损失了梯度。
我还尝试了 2 个简单的张量:
a = torch.tensor([1,2,3,4])
threshold = torch.tensor([0.5,2.3,2.9,4.2])
torch.where(a < threshold, 0, a)
>>>tensor([1, 0, 3, 0])
这似乎有效(尽管在这种情况下我没有关于渐变的参考,也不知道为什么它在这种情况下有效,而不是在其他情况下有效,因为我需要第一个工作)
解决方案
推荐阅读
- visual-studio - Visual Studio 2017 中的“在文件中查找”不起作用
- java - 如何使用java(spring boot)以json格式检索完整值(带小数的大整数值)?
- akavache - 在仅 aot 模式下运行时尝试 JIT 编译方法“Akavache.Sqlite3.Registrations:Register (Splat.IMutableDependencyResolver)”
- python - 字典在python中按月天数(键)排序
- bash - bash 等效于 zsh 的 '=' 命令路径扩展?
- reactjs - Reactadmin - 编辑用户角色
- java - 在 Spring Boot 中启动应用程序上下文时出错
- android - Android - Firebase Crashlytics - “无法检索设置”
- laravel - Laravel Firebase 多数据库支持
- javascript - 数据表 - M 脚本在第一个页面之外的任何其他页面上都不起作用,如何修复它?