首页 > 解决方案 > Torch - 插入缺失值

问题描述

我有一堆包含零的 NumOfImagesxHxW 形式的张量图像。我正在寻找一种仅使用同一图像中的信息(图像之间没有连接)来插入缺失值(零)的方法。有没有办法使用pytorch来做到这一点?

F.interpolate 似乎只适用于重塑。我需要填充零,同时保持张量的尺寸和梯度。

谢谢。

标签: pytorchinterpolation

解决方案


如果您只想要最近邻插值,则可以通过返回索引映射而不是插值图像来使@Yuri Feldman 的答案可区分。

我所做的是创建一个新类scipy.interpolate.NearestNDInterpolator并覆盖它的__call__方法。它只是返回索引而不是值。

from scipy.interpolate.interpnd import _ndim_coords_from_arrays

class NearestNDInterpolatorIndex(NearestNDInterpolator):
    def __init__(self, x, y, rescale=False, tree_options=None):
        NearestNDInterpolator.__init__(self, x, y, rescale=rescale, tree_options=tree_options)
        self.points = np.asarray(x)

    def __call__(self, *args):
        """
        Evaluate interpolator at given points.
        Parameters
        ----------
        xi : ndarray of float, shape (..., ndim)
            Points where to interpolate data at.
        """
        xi = _ndim_coords_from_arrays(args, ndim=self.points.shape[1])
        xi = self._check_call_shape(xi)
        xi = self._scale_x(xi)
        dist, i = self.tree.query(xi)
        return self.points[i]

然后,在 中fillMissingValues,而不是返回target_for_interp,我们返回这些:

source_indices = np.argwhere(invalid_mask)
target_indices = interp(source_indices)

return source_indices, target_indices

将新的插值器传递给fillMissingValues,那么我们可以得到图像的最近邻插值

img[..., source_indices[:, 0], source_indices[:, 1]] = img[..., target_indices[:, 0], target_indices[:, 1]]

假设图像大小在最后两个维度上。


编辑:这是不可区分的,因为我刚刚测试过。问题在于索引映射。我们需要使用掩码而不是就地操作,然后问题就解决了。


推荐阅读