首页 > 解决方案 > PyTorch "where" 有条件的 -- RuntimeError: 预期的标量类型 long long 但发现 float

问题描述

我有 2 个带渐变的张量:

a = tensor([[0.0000, 0.0000, 0.2716, 0.0000, 0.4049, 0.0000, 0.2126, 0.8649, 0.0000,
         0.0000]], grad_fn=<ReluBackward0>)

b = tensor([[0.5842, 0.4618, 0.4047, 0.5714, 0.4841, 0.5683, 0.4030, 0.3779, 0.4436,
         0.4365]], grad_fn=<SigmoidBackward>)

我正在尝试使用第二个张量 ( b) 作为阈值,同时保持张量的可微性:

torch.where(a < b, 0, a)

但是,我收到一个错误

RuntimeError: expected scalar type long long but found float

我可以将张量转换为longwith

a = torch.tensor([0.0000, 0.0736, 0.5220, 0.0000, 0.0000, 0.1783, 0.0000, 0.0000, 0.0000,
         0.0000]).type(torch.LongTensor)

b = torch.tensor([0.4596, 0.4635, 0.5073, 0.4358, 0.5551, 0.5089, 0.5348, 0.5573, 0.5656,
         0.5886]).type(torch.LongTensor)

然后条件操作可以正常工作:

torch.where(a < b, 0, a)

虽然 1. 它给了我错误的答案(它只是将每个张量转换为零):

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

2.它损失了梯度。

我还尝试了 2 个简单的张量:

a = torch.tensor([1,2,3,4])
threshold = torch.tensor([0.5,2.3,2.9,4.2])

torch.where(a < threshold, 0, a)

>>>tensor([1, 0, 3, 0])

这似乎有效(尽管在这种情况下我没有关于渐变的参考,也不知道为什么它在这种情况下有效,而不是在其他情况下有效,因为我需要第一个工作)

标签: pythonpytorchconditional-statementsdifferentiation

解决方案


推荐阅读