Avoiding gradient breakage when combining tensors with `torch.where`

Problem description

Here is a very short example illustrating the problem I ran into and what I'm after:

import torch as th
import torch.nn as nn

nan = float("nan")
t = th.Tensor([[nan, nan], [0, 0], [.1, .1]])

l = nn.Linear(2, 2)
tf = l(t)  # the nan row passes through the linear layer

# replace the outputs of nan rows with zeros (x != x is true only for nan)
tf = th.where(t != t, th.zeros_like(t), tf)
(10 - tf.sum()).backward()
print(l.weight.grad)  # all nan

Why is the gradient nan here (especially since the following works fine)?

import torch as th
import torch.nn as nn

nan = float("nan")
t = th.Tensor([[nan, nan], [0, 0], [.1, .1]])
# replace nan entries with zeros *before* the linear layer
t_o = th.where(t != t, th.zeros_like(t), t)

l = nn.Linear(2, 2)
tf = l(t_o)

(10 - tf.sum()).backward()
print(l.weight.grad)

which gives the correct gradient:

tensor([[-0.1000, -0.1000],
        [-0.1000, -0.1000]])
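
My suspicion is that the nan sneaks in through the weight gradient itself: `th.where` routes an exact-zero gradient to the unselected branch, but the linear layer's weight gradient multiplies that zero with the nan inputs saved during the forward pass, and 0 * nan = nan. A tiny sketch with a hypothetical scalar weight w seems to confirm this:

import torch as th

nan = float("nan")
x = th.tensor([nan, 1.0])               # input with a nan entry
w = th.tensor(2.0, requires_grad=True)  # stand-in for a learnable weight

y = x * w                               # nan flows through the forward
z = th.where(th.isnan(y), th.zeros_like(y), y)
z.sum().backward()

# grad_w = sum(routed_grad * x) = 0 * nan + 1 * 1.0 = nan
print(w.grad)  # tensor(nan)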

The full problem

Suppose I have tensors t of shape (n, f_i), where f_i may differ between the tensors. Additionally, each of these t tensors may have some nan rows (say each tensor represents one kind of object, and I want a shared embedding for all of these objects). What I want to do is "assimilate" these tensors into a shared space.

Here is a minimal module attempting to do this:

import torch as th
import torch.nn as nn

class Assimilation(nn.Module):
    def __init__(self, assimilation_dim: int, *type_in_dim):
        super().__init__()
        self._assimilation_dim = assimilation_dim
        self._type_in_dim = type_in_dim
        # one projection into the shared space per input type
        self._mixins = nn.ModuleList(
            [
                nn.Linear(in_dim, self._assimilation_dim, bias=False)
                for in_dim in self._type_in_dim
            ]
        )

    def forward(self, *x: th.Tensor):
        assims = []
        for to_assim, assim_layer in zip(x, self._mixins):
            assimilated = assim_layer(to_assim)
            assims.append(assimilated)

        # the problem area (I believe): fill nan rows of the running result
        # with the corresponding rows of the next assimilated tensor
        current = assims[0]
        for i in range(1, len(assims)):
            current = th.where(current != current, assims[i], current)
        return current

Note that the real implementation has a few more requirements, such as allowing tensors of more arbitrary shape (*, n, f_i) with extra leading dimensions. For that reason I would like to avoid tracking indices, since things get much more complicated when batching such data (or adding time, etc.) - especially since torch's fancy indexing isn't great.
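
For what it's worth, `nn.Linear` already broadcasts over leading dimensions, which is part of why a where-based approach (rather than index bookkeeping) is attractive here; a quick check with hypothetical shapes:

import torch as th
import torch.nn as nn

l = nn.Linear(4, 8)
x = th.randn(3, 5, 7, 4)   # arbitrary leading dims, features last
print(l(x).shape)          # torch.Size([3, 5, 7, 8])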

However, when actually using it, you can observe the gradients breaking at the th.where:

ass = Assimilation(2, 2, 3)  # :) shared dim 2; input widths 2 and 3
nan = float("nan")

t1 = th.Tensor([[nan, nan], [0, 0], [.1, .1]])
t2 = th.Tensor([[100, 100, 100], [nan, nan, nan], [nan, nan, nan]])

res = ass(t1, t2)
print(t1.shape, t2.shape)
print(res.shape)
print(res)
(10 - res.sum()).backward()
print(ass._mixins[0].weight.grad)  # all nan
print(ass._mixins[1].weight.grad)  # all nan

Is there a reason th.where breaks the gradients here? How can I avoid it?

Tags: python, python-3.x, deep-learning, pytorch, tensor

Solution
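
The gradient is not actually broken by `th.where` itself: its backward routes an exact zero to the branch that was not selected. The problem is that the nans have already passed through the `nn.Linear` forward, so the layer's saved input contains nans, and the weight gradient is computed as grad_output^T @ input; 0 * nan = nan, and those nans get summed into `weight.grad`. The fix, as in the second snippet above, is to mask the nans *before* they reach any layer whose parameters need gradients. One way to do this is a sketch like the following (only the forward logic changes; per-row nan masks are carried separately, since the cleaned activations no longer contain nans to test for):

import torch as th
import torch.nn as nn

class Assimilation(nn.Module):
    def __init__(self, assimilation_dim: int, *type_in_dim):
        super().__init__()
        self._assimilation_dim = assimilation_dim
        self._type_in_dim = type_in_dim
        self._mixins = nn.ModuleList(
            [
                nn.Linear(in_dim, self._assimilation_dim, bias=False)
                for in_dim in self._type_in_dim
            ]
        )

    def forward(self, *x: th.Tensor):
        assims = []
        masks = []
        for to_assim, assim_layer in zip(x, self._mixins):
            # zero out nan rows *before* the linear layer, so the layer
            # never sees (or saves) nan inputs for its backward pass
            nan_rows = th.isnan(to_assim).any(dim=-1, keepdim=True)
            cleaned = th.where(nan_rows, th.zeros_like(to_assim), to_assim)
            assims.append(assim_layer(cleaned))
            masks.append(nan_rows)

        # per row, keep the first assimilated tensor whose source row
        # was not nan; fall back to later tensors otherwise
        current = assims[0]
        still_missing = masks[0]
        for i in range(1, len(assims)):
            current = th.where(still_missing, assims[i], current)
            still_missing = still_missing & masks[i]
        return current

With this version, the zero gradients routed by `th.where` are multiplied with zeros instead of nans in the weight-gradient matmul, so `ass._mixins[0].weight.grad` and `ass._mixins[1].weight.grad` should come out finite.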

