python - How to randomly set a variable number of elements in each row of a tensor in PyTorch
问题描述
I want to create a zero-one matrix of dimension (n, n). The ones should be placed randomly, with a cap on the number of ones in each row. Let us say I have a list of length n that has the value of cap for each of the n rows. How can I do this in PyTorch?
My question is similar to this previous question. The only change I am looking for is, there should be n values of k, corresponding to n rows.
解决方案
正如@Marcel在上面的评论中所解释的那样,您可以首先将第一个m
值设置为 value,k
然后按置换索引进行索引,以获得一个 shuffle 张量:
>>> n = 10; m = 3; k = 1
>>> x = torch.zeros(n, n)
>>> x[:, :m] = k
tensor([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0., 0., 0., 0., 0.]])
用于torch.randperm
获取逐行列排列:
>>> perm = torch.stack([torch.randperm(10) for _ in range(len(x))])
tensor([[8, 0, 3, 2, 1, 6, 9, 4, 5, 7],
[5, 7, 1, 4, 8, 0, 6, 9, 2, 3],
[2, 1, 9, 7, 0, 8, 6, 3, 5, 4],
[1, 3, 5, 8, 7, 6, 9, 4, 2, 0],
[7, 6, 0, 5, 2, 9, 1, 8, 4, 3],
[5, 0, 6, 8, 1, 9, 2, 4, 3, 7],
[4, 0, 6, 5, 8, 1, 3, 7, 2, 9],
[5, 3, 4, 9, 0, 1, 7, 6, 8, 2],
[5, 7, 9, 3, 2, 6, 8, 0, 4, 1],
[2, 7, 4, 6, 3, 0, 9, 8, 5, 1]])
然后使用torch.gather
索引张x
量perm
:
>>> x.gather(dim=0, index=perm)
tensor([[0., 1., 0., 1., 1., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 1., 0., 0., 1., 0.],
[1., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
[0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0., 0., 1., 0., 1.],
[1., 0., 0., 0., 0., 1., 0., 0., 0., 1.]])
或者,您可以torch.scatter
直接使用value
关键字参数:
>>> torch.zeros(n, n).scatter(dim=0, index=perm, value=1)
tensor([[0., 1., 0., 1., 1., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 1., 0., 0., 1., 0.],
[1., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 1., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
[0., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
[0., 0., 0., 0., 1., 0., 0., 1., 0., 1.],
[1., 0., 0., 0., 0., 1., 0., 0., 0., 1.]])
如果m
是张量本身,您可以使用torch.arange
和的组合找到解决方法torch.where
:
首先对位置进行编码:
>>> d = torch.arange(n)[None].repeat(n,1)
>>> x = torch.where(d+m>n, 0, 1)
tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
像以前一样构造排列:
>>> perm = torch.stack([torch.randperm(10) for _ in range(n)])
tensor([[2, 5, 7, 0, 4, 1, 3, 6, 8, 9],
[7, 4, 9, 5, 6, 0, 3, 1, 2, 8],
[5, 1, 4, 9, 0, 3, 2, 6, 7, 8],
[9, 6, 0, 2, 3, 1, 7, 5, 4, 8],
[3, 5, 4, 6, 0, 7, 9, 8, 2, 1],
[5, 7, 8, 6, 9, 2, 0, 4, 3, 1],
[8, 3, 9, 0, 6, 2, 5, 7, 4, 1],
[2, 9, 4, 3, 7, 8, 1, 0, 6, 5],
[5, 4, 8, 3, 2, 9, 7, 1, 6, 0],
[8, 7, 3, 6, 5, 4, 2, 0, 9, 1]])
然后分散在x
:
>>> x.scatter(dim=0, index=perm, value=1)
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 0, 0, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
推荐阅读
- php - 更新列名中包含特殊字符的数据时,数据表编辑器会出错
- cron - 如何使用默认队列在本地查看我的 sidekiq 控制台输出?
- mysql - 从命令行打开mysql的麻烦
- c - Linux kernel function call flow
- angular - CDK 拖放无法正确更改图像的位置
- windows - 在 Windows 上记录 USB 设备
- css - 将规则应用于同一级别中的所有节点,最后一行中的节点除外
- sql-server - 在一个表中查找可能在第二个表中的多个列中的匹配列记录
- r - ggplot R中geom_bar中离散值的自定义颜色
- python - Tensorflow 2 没有完全连接的函数我该如何模拟呢?