python - 避免使用 `torch.where` 组合张量的梯度破坏
问题描述
这是一个非常简短的示例,说明了我遇到的问题以及我所追求的
nan = float("nan")
t = th.Tensor([[nan, nan], [0, 0], [.1, .1],])
l = nn.Linear(2, 2)
tf = l(t)
tf = th.where(t != t, th.zeros_like(t), tf)
(10 - tf.sum()).backward()
print(l.weight.grad)
为什么梯度是nan(尤其是当以下就可以了)
nan = float("nan")
t = th.Tensor([[nan, nan], [0, 0], [.1, .1],])
t_o = th.where(t != t, th.zeros_like(t), t)
l = nn.Linear(2, 2)
tf = l(t_o)
(10 - tf.sum()).backward()
print(l.weight.grad)
给出正确的梯度:
tensor([[-0.1000, -0.1000],
[-0.1000, -0.1000]])
完整的问题
假设我有t
形状张量,(n, f_i)
每个张量f_i
可能不同t
。此外,这些t
张量中的每一个都可能有一些nan
行(假设每个张量代表一种对象,我想要所有这些对象的共享嵌入)。我想做的是将这些张量“同化”到一个共享空间中。
这是尝试执行此操作的最小模块:
import torch as th
import torch.nn as n
class Assimilation(nn.Module):
def __init__(self, assimilation_dim: int, *type_in_dim):
super().__init__()
self._assimilation_dim = assimilation_dim
self._type_in_dim = type_in_dim
self._mixins = nn.ModuleList(
[
nn.Linear(in_dim, self._assimilation_dim, bias=False)
for in_dim in self._type_in_dim
]
)
def forward(self, *x: th.Tensor):
assims, = []
for to_assim, assim_layer in zip(x, self._mixins):
assimilated = assim_layer(to_assim)
assims.append(assimilated)
# the problem area (I believe)
current = assims[0]
for i in range(1, len(assims)):
current = th.where(current != current, assims[i], current)
return current
请注意,该实现有更多要求,例如允许更多任意形状的张量(*, n, f_i)
用于更多前导维度。出于这个原因,我想避免跟踪索引,因为在批处理此类数据(或添加时间等)时事情变得更加复杂 - 特别是因为 torch的花式索引不好。
但是,当尝试使用它时,您可以观察到渐变从th.where
ass = Assimilation(2, 2, 3) # :)
nan = float("nan")
t1 = th.Tensor([[nan, nan], [0, 0], [.1, .1],])
t2 = th.Tensor([[100, 100, 100], [nan, nan, nan], [nan, nan, nan],])
res = ass(t1, t2)
print(t1.shape, t2.shape)
print(res.shape)
print(res)
(10 - res.sum()).backward()
print(ass._mixins[0].weight.grad) # all nan
print(ass._mixins[1].weight.grad) # all nan
是否有理由th.where
在这里打破渐变?我怎样才能避免这种情况?
解决方案
推荐阅读
- ios - CocoaPods Firebase 最新版本无法更新
- gradle - 注意:使用 -Xlint:deprecation 重新编译以获取详细信息。(已解决)
- javascript - SharePoint 列表 CSR 默认呈现回退
- javascript - 异步 WebSocket JS 类
- deep-learning - 为什么在验证集上设置 shuffle=False 在混淆矩阵和分类报告中比 shuffle=True 提供更好的结果?
- google-apps-script - Google Web App 不会在开发脚本上记录任何内容
- python - 想要掌握这个 pyautogui 命令
- r - 在 R 中合并具有不同大小和条件的数据框
- c++ - 为什么我在使用 fwrite() 和 fread() 时遇到问题?
- mysql - 为什么 mySQL 工作台总是将二进制数据显示为 blob?