首页 > 解决方案 > How to randomly set a variable number of elements in each row of a tensor in PyTorch

问题描述

I want to create a zero-one matrix of dimension (n, n). The ones should be placed randomly, with a cap on the number of ones in each row. Let us say I have a list of length n that has the value of cap for each of the n rows. How can I do this in PyTorch?

My question is similar to this previous question. The only change I am looking for is, there should be n values of k, corresponding to n rows.

标签: pythonpytorch

解决方案


正如@Marcel在上面的评论中所解释的那样,您可以首先将第一个m值设置为 value,k然后按置换索引进行索引,以获得一个 shuffle 张量:

>>> n = 10; m = 3; k = 1
>>> x = torch.zeros(n, n)

>>> x[:, :m] = k
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.]])

用于torch.randperm获取逐行列排列:

>>> perm = torch.stack([torch.randperm(10) for _ in range(len(x))])
tensor([[8, 0, 3, 2, 1, 6, 9, 4, 5, 7],
        [5, 7, 1, 4, 8, 0, 6, 9, 2, 3],
        [2, 1, 9, 7, 0, 8, 6, 3, 5, 4],
        [1, 3, 5, 8, 7, 6, 9, 4, 2, 0],
        [7, 6, 0, 5, 2, 9, 1, 8, 4, 3],
        [5, 0, 6, 8, 1, 9, 2, 4, 3, 7],
        [4, 0, 6, 5, 8, 1, 3, 7, 2, 9],
        [5, 3, 4, 9, 0, 1, 7, 6, 8, 2],
        [5, 7, 9, 3, 2, 6, 8, 0, 4, 1],
        [2, 7, 4, 6, 3, 0, 9, 8, 5, 1]])

然后使用torch.gather索引张xperm

>>> x.gather(dim=0, index=perm)
tensor([[0., 1., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1., 0.],
        [1., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0., 1.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 1.]])

或者,您可以torch.scatter直接使用value关键字参数:

>>> torch.zeros(n, n).scatter(dim=0, index=perm, value=1)
tensor([[0., 1., 0., 1., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1., 0.],
        [1., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0., 0., 1., 0., 1.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 1.]])

如果m是张量本身,您可以使用torch.arange和的组合找到解决方法torch.where

首先对位置进行编码:

>>> d = torch.arange(n)[None].repeat(n,1)
>>> x = torch.where(d+m>n, 0, 1)
tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

像以前一样构造排列:

>>> perm = torch.stack([torch.randperm(10) for _ in range(n)])
tensor([[2, 5, 7, 0, 4, 1, 3, 6, 8, 9],
        [7, 4, 9, 5, 6, 0, 3, 1, 2, 8],
        [5, 1, 4, 9, 0, 3, 2, 6, 7, 8],
        [9, 6, 0, 2, 3, 1, 7, 5, 4, 8],
        [3, 5, 4, 6, 0, 7, 9, 8, 2, 1],
        [5, 7, 8, 6, 9, 2, 0, 4, 3, 1],
        [8, 3, 9, 0, 6, 2, 5, 7, 4, 1],
        [2, 9, 4, 3, 7, 8, 1, 0, 6, 5],
        [5, 4, 8, 3, 2, 9, 7, 1, 6, 0],
        [8, 7, 3, 6, 5, 4, 2, 0, 9, 1]])

然后分散在x

>>> x.scatter(dim=0, index=perm, value=1)
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 1],
        [1, 1, 1, 0, 0, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

推荐阅读