python - torch.nn.BCELoss() 的两个参数中的导数
问题描述
当使用torch.nn.BCELoss()
两个参数时,这两个参数都是一些早期计算的结果,我得到了一些奇怪的错误,这个问题是关于:
RuntimeError: the derivative for 'target' is not implemented
MCVE 如下:
import torch
import torch.nn.functional as F
net1 = torch.nn.Linear(1,1)
net2 = torch.nn.Linear(1,1)
loss_fcn = torch.nn.BCELoss()
x = torch.zeros((1,1))
y = F.sigmoid(net1(x)) #make sure y is in range (0,1)
z = F.sigmoid(net2(y)) #make sure z is in range (0,1)
loss = loss_fcn(z, y) #works if we replace y with y.detach()
loss.backward()
事实证明,如果我们调用.detach()
错误就会y
消失。但这会导致不同的计算,现在在-pass 中,不会计算.backward()
相对于第二个参数的梯度。BCELoss
谁能解释我在这种情况下做错了什么?据我所知,所有 pytorch 模块都torch.nn
应该支持计算梯度。这条错误消息似乎告诉我,导数没有实现y
,这有点奇怪,因为您可以计算 的梯度y
,但y.detach()
似乎不矛盾。
解决方案
看来我误解了错误信息。不是y
不允许计算梯度,而是BCELoss()
没有能力计算关于第二个参数的梯度。这里讨论了一个类似的问题。
推荐阅读
- javascript - 如何在 JavaScript 中制作计数动画?
- java - 如何使用 XPATH 从 XML 中检索特定元素的子元素?
- python - 使用 odeint 和 sympy 求解微分方程
- typescript - 我可以在许多单独的 typedoc 注释中共享链接的通用链接引用吗?
- react-native - 如何在底部导航的第一个标签页上拥有顶部导航和底部导航,这是一个单独的页面?
- java - 将 weblogic JMS 移动到 Oracle 高级队列时出错
- android - 如何在改造时发送对象列表
- android - 结合 Android 主题
- javascript - 如何使用动态表中的唯一 ID 通过按钮传递数据
- svelte - 无法在 Svelte 中使用 Fabric-js