python - 如何在二维张量中获得相同值的张量索引?
问题描述
如标题中所述,给定一个二维张量,假设:
tensor([
[0, 1, 0, 1], # A
[1, 1, 0, 1], # B
[1, 0, 0, 1], # C
[0, 1, 0, 1], # D
[1, 1, 0, 1], # E
[1, 1, 0, 1] # F
])
这很容易看出“A 和 D”、“B、E 和 F”是两组张量,
具有相同的值(这意味着 A == D 和 B == E == F)。
所以我的问题是:
如何获取这些组的索引?
细节:
输入:上面的张量
输出:(0, 3), (1, 4, 5)
解决方案
使用 PyTorch 函数的解决方案:
import torch
x = torch.tensor([
[0, 1, 0, 1], # A
[1, 1, 0, 1], # B
[1, 0, 0, 1], # C
[0, 1, 0, 1], # D
[1, 1, 0, 1], # E
[1, 1, 0, 1] # F
])
_, inv, counts = torch.unique(x, dim=0, return_inverse=True, return_counts=True)
print([tuple(torch.where(inv == i)[0].tolist()) for i, c, in enumerate(counts) if counts[i] > 1])
# > [(0, 3), (1, 4, 5)]
推荐阅读
- php - 如何将用户 ID 保存在 laravel 中特定用户使用软删除删除的记录中
- python - 如何读取文件夹中的所有图像?
- bash - Bash 解析器在命令行中以什么顺序转义字符和拆分单词/标记?
- swift - 为具有失败响应的字典数组实现 Codable
- javascript - Retyped.chartist 用法不清楚。如何创建基本折线图?
- android - 从 Firestore 获取数据后,方法总是返回 null
- java - 内存分配 Java 异常
- php - 将数据从一个表复制到另一个没有 ID
- java - 如何在 JavaFX 中延迟 10 秒在 TextArea 中附加文本?
- sed - 找到匹配并删除它们,只留下不匹配的