首页 > 解决方案 > 以并行方式在 3 维数组中查找某个值的索引

问题描述

我在我的 pytorch 文件中间创建了一个zero_mask2函数(复制如下)。但是,它太慢了。所以,我正在寻找更好的方法。

首先,让我解释一下下面的函数的要点。

输入

玩具示例

可以说,有一个all_idx_before(9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5, 18, 19, 15, 16, 17, 13, 1, 20)

并且有idx一点(10, 6, 2, 4, 18)

在这个设置中,我想找到每个值的索引idx。例如,(10)inidx是 中的第二个元素all_idx_before(6)in是...中idx的第五个元素all_idx_before

所以,这一点的潜在输出就像(1, 5, 8, 10, 12), ('second' becomes '1' because of number system in python.)

我在每个批次和每个点都执行此逻辑。所以,我用于迭代。但这太慢了。有没有办法并行执行此操作?我使用 NumPy 格式。但如果它更好,我可以使用张量。

def zero_mask2(idx, all_idx_before, p):

size_batch = len(idx)
size_points = len(idx[0])
size_neighbor = len(idx[0][0])

mask = torch.empty(size_batch, size_points, size_neighbor)

for i in range(size_batch):
 for j in range(size_points):
  for l in range(size_neighbor):
   mask[i][j][l] = np.where(all_idx_before[i][j] == idx[i][j][l])[0][0] 

mask_t = mask <= p # smaller than p or same : true , bigger than p : false
return mask_t

标签: pythonnumpyindexingpytorch

解决方案


您可以比较和使用argmax

tmp = all_idx_before[..., None, :] == idx[..., None]  # compare along additional last dimension
mask = torch.argmax(tmp, dim=-1)

推荐阅读