首页 > 解决方案 > Pytorch:创建一个大于批次中每个 2D 张量的第 n 个分位数的掩码

问题描述

我有一个torch.Tensor形状(2, 2, 2)(可以更大),其中值在 range 内标准化[0, 1]

现在我得到一个正整数K,它告诉我需要创建一个掩码,其中对于批次中的每个 2D 张量,如果它大于1/k所有值,则值为 1,而在其他地方为 0。返回面罩也有形状(2, 2, 2)

例如,如果我有这样的批次:

tensor([[[1., 3.],
         [2., 4.]],
        [[5., 7.],
         [9., 8.]]])

let K=2,这意味着我必须屏蔽大于每个 2D 张量内所有值的 50% 的值。

在示例中,0.5 分位数是2.57.5,因此这是所需的输出:

tensor([[[0, 1],
         [0, 1]],
        [[0, 0],
         [1, 1]]])

我试过了:

a = torch.tensor([[[0, 1],
                   [0, 1]],
                  [[0, 0],
                   [1, 1]]])
quantile = torch.tensor([torch.quantile(x, 1/K) for x in a])
torch.where(a > val, 1, 0)

但这是结果:

tensor([[[0, 0],
         [0, 0]],
        [[1, 0],
         [1, 1]]])

标签: pythonpytorch

解决方案


t = torch.tensor([[[1., 3.],
         [2., 4.]],
        [[5., 7.],
         [9., 8.]]])

t_flat = torch.reshape(t, (t.shape[0], -1))
quants = torch.quantile(t_flat, 1/K, dim=1)
quants = torch..reshape(quants, (quants.shape[0], 1, 1))
res = torch.where(t > val, 1, 0)

在这个 res 之后是:

tensor([[[0, 1],
         [0, 1]],

        [[0, 0],
         [1, 1]]])

这就是你想要的


推荐阅读