首页 > 解决方案 > torch.where() 可以以等效的广播形式使用吗?

问题描述

我的代码中有以下 for 循环段。嵌套循环正在减慢我的完整执行速度。

for q in range(batchSize):
    temp=torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp)==0:
        output[q]=0
    else:
        output[q]=int(temp[0])

这里,composition_matrix[14000,2]只有正整数作为单元格值的维度 pytorch 张量。pred两者output都是[batchSize,2]三维火炬张量。由于这个 for 循环大大减慢了我的代码,我无法获得与此代码段等效的广播解决方案。

是否存在广播解决方案来消除此 for 循环?

我将不胜感激任何帮助。

一个最小可重现的例子是

import torch
composition_matrix=torch.randint(3, 10, (14000,2))
batchSize=64
pred=torch.randint(3, 10, (batchSize,2))
output=torch.zeros([batchSize])

for q in range(batchSize):
    temp=torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp)==0:
        output[q]=0
    else:
        output[q]=int(temp[0])

标签: numpypytorchnumpy-ndarrayarray-broadcastingnumpy-slicing

解决方案


为简单起见,您首先需要了解操作本质上在做什么。你有两个张量。张量 A 的形状(14000, 2)和张量 B 的形状(64, 2)。您要做的操作是:

对于 B 中的每一行 B[i],将 B[i](形状为 (2,))与 A(形状为 (14000, 2))进行比较。如果 B[i] 出现在 A 中,则设置 output[i] =首次出现的索引。

这实际上可以在两行代码中完成(甚至可能是一行):

comp = (composition_matrix[:, None, :] == pred).all(dim=-1)
output = torch.argmax(comp.float(), axis=0)
  • 第一行创建了一个和comp的广播比较,一个布尔张量。composition_matrixpred(14000, 64)

  • 第二行需要找到“第一个匹配的索引”。这可以通过 argmax 非常简单地完成:它将返回第一个“1”的索引(或者如果所有值都是“0”,则返回第一个索引,即 0)。

(请注意,torch 不支持“bool”张量的 argmax,因此 comp 需要转换为另一种数据类型。)


推荐阅读