首页 > 解决方案 > 如何在没有for循环的情况下进行元素比较和

问题描述

for 循环使我的程序非常慢。我会使用 np.sum(target==output) 但我需要输出中每一行的 argmax 值。我怎样才能加快速度?输出是张量数据类型

for i, x in enumerate(target):
  if target[i] == torch.argmax(output[i]):
    correct_class += 1

标签: pythonnumpypytorch

解决方案


np.argmax您可以使用'参数对上述内容进行矢量化axis,以获得跨行的最大值的索引:

(target==np.argmax(output, axis=1)).sum()

例如:

output = np.random.choice([0,1],(4,2))
print(output)
array([[1, 1],
       [0, 1],
       [0, 1],
       [0, 1]])
target = np.array([[0,1,0,1]])
(target==np.argmax(output, axis=1)).sum()
# 3

推荐阅读