首页 > 解决方案 > 在布尔张量掩码中查找峰值点的索引(第一个 True)

问题描述

在 pytorch 中运行 Detectron2 模型后,Detectron2 为我提供了它发现的对象掩码作为(真/假)张量。在图像中发现了 33 个对象,所以我有 torch.Size([33, 683, 1024])。

tensor([[False, False, False,  ..., False, False, False],
    [False, False, False,  ..., False, False, False],
    [False, False, False,  ..., False, False, False],
    ...,
    [False, False, False,  ..., False, False, False],
    [False, False, False,  ..., False, False, False],
    [False, False, False,  ..., False, False, False]], device='cuda:0')

到目前为止这很棒。但我需要这 33 个对象在 y 维度(高度)中的峰值坐标。(假设物体是气球,那么我需要气球的顶部作为(x,y)点)

知道如何尽快获得峰值点坐标,谢谢

标签: pythonpytorch

解决方案


我已经遍历了每个维度并检查了 True 条件是否满足,但需要几分钟才能找到索引

然后我使用了 torch.where 方法,它立即找到了所有满足条件的索引。

for maskCounter in range(masks.shape[0]):
    print((torch.where(masks[maskCounter] == True)[0][0]).item(), (torch.where(masks[maskCounter] == True)[1][0]).item())

推荐阅读