python - PyTorch - 另一个张量中对应值的索引
问题描述
我有一个张量,我只想复制一些值(按列)。相同的值在另一个张量中,但顺序是随机的。我想要的是来自tensor2
的值的列索引tensor1
。这是一个例子:
copy_ind = torch.tensor([0, 1, 3], dtype=torch.long)
tensor1 = torch.tensor([[4, 6, 5, 1, 8],[10, 0, 8, 2, 1]])
temp = torch.index_select(tensor1, 1, copy_ind) # values to copy
tensor2 = torch.tensor([[1, 4, 5, 6, 8],[2, 10, 8, 0, 1]], dtype=torch.long)
_, t_ind = torch.sort(temp[0], dim=0)
t2_ind = copy_ind[t_ind] # indices of tensor2
输出应该是:
t2_ind = [1, 3, 0]
这是另一个示例,我想根据以下方式获取张量的值c1_new
:
c1 = torch.tensor([[6, 7, 7, 8, 6, 8, 9, 4, 7, 6, 1, 3],[5, 11, 5, 7, 2, 9, 5, 5, 7, 11, 10, 7]], dtype=torch.long)
copy_ind = torch.tensor([1, 2, 3, 5, 7, 8], dtype=torch.long)
c1_new = torch.index_select(c1, 1, copy_ind)
indices = torch.as_tensor([[1, 3, 4, 6, 6, 6, 7, 7, 7, 8, 8, 9], [10, 7, 5, 2, 5, 11, 5, 7, 11, 7, 9, 5]])
values = torch.randn(12)
tensor = torch.sparse.FloatTensor(indices, values, (12, 12))
_, t_ind = torch.sort(c1[0], dim=0)
ind = t_ind[copy_ind] # should be [8, 6, 9, 10, 2, 7]
不幸的是,指数ind
不正确。有人可以帮帮我吗?
解决方案
如果您可以使用 for 循环,您可以使用以下内容:检查您的临时张量的每一列与 tensor2 的列:
编辑:使用torch.prod
跨维度 1 确保两行匹配
[torch.prod((temp.T[i] == tesnor2.T), dim=1).nonzero()[0] for i in range(temp.size(1))]
我的第一个例子的输出是[tensor(1), tensor(3), tensor(0)]
推荐阅读
- java - Android项目和Android APP上的Jar Lib
- php - 从数组中删除重复的 yyyy/mm
- python - Python:如何计算文件中数字的总和?
- angular - 登录后获取更新的 XSRF 令牌和 Cookie:Angular 4
- html - AngularJS ng-if 三元条件属性使用
- regex - Django、url 和表单:在 url 中发送参数有什么意义?
- javascript - Ajax 调用后重定向到另一个页面?
- excel - Excel CustomUI 按钮链接
- r - 在 ggplot2 中设置 scale_colour_gradientn 的值
- ethereum - 使用松露,如何只获取提交交易的 tx 哈希而不等待收据