numpy - torch.where() 可以以等效的广播形式使用吗?
问题描述
我的代码中有以下 for 循环段。嵌套循环正在减慢我的完整执行速度。
for q in range(batchSize):
temp=torch.where((composition_matrix == pred[q]).all(dim=1))[0]
if len(temp)==0:
output[q]=0
else:
output[q]=int(temp[0])
这里,composition_matrix
是[14000,2]
只有正整数作为单元格值的维度 pytorch 张量。pred
两者output
都是[batchSize,2]
三维火炬张量。由于这个 for 循环大大减慢了我的代码,我无法获得与此代码段等效的广播解决方案。
是否存在广播解决方案来消除此 for 循环?
我将不胜感激任何帮助。
一个最小可重现的例子是
import torch
composition_matrix=torch.randint(3, 10, (14000,2))
batchSize=64
pred=torch.randint(3, 10, (batchSize,2))
output=torch.zeros([batchSize])
for q in range(batchSize):
temp=torch.where((composition_matrix == pred[q]).all(dim=1))[0]
if len(temp)==0:
output[q]=0
else:
output[q]=int(temp[0])
解决方案
为简单起见,您首先需要了解操作本质上在做什么。你有两个张量。张量 A 的形状(14000, 2)
和张量 B 的形状(64, 2)
。您要做的操作是:
对于 B 中的每一行 B[i],将 B[i](形状为 (2,))与 A(形状为 (14000, 2))进行比较。如果 B[i] 出现在 A 中,则设置 output[i] =首次出现的索引。
这实际上可以在两行代码中完成(甚至可能是一行):
comp = (composition_matrix[:, None, :] == pred).all(dim=-1)
output = torch.argmax(comp.float(), axis=0)
第一行创建了一个和
comp
的广播比较,一个布尔张量。composition_matrix
pred
(14000, 64)
第二行需要找到“第一个匹配的索引”。这可以通过 argmax 非常简单地完成:它将返回第一个“1”的索引(或者如果所有值都是“0”,则返回第一个索引,即 0)。
(请注意,torch 不支持“bool”张量的 argmax,因此 comp 需要转换为另一种数据类型。)
推荐阅读
- php - laravel 中的图片数组
- http - 正确的 URL 地址但未找到资源的正确 HTTP 代码
- r - 如何使用 tidyverse 将我的数据框分成 10 行?
- xamarin.forms - IOS 中的 Xamarin 表单选择器不会换行
- .net - 如何更改 .Net Framework 的运行时版本
- sql-server - SQL Server 导出向导正在将所有数据更改为 nvarchar
- java - 我将 hibernate-validator 升级到版本 6.1.5,应用程序在 WAS 8.5.5.17 上不起作用
- python - 根据数据框中的数据创建新变量,忽略 NaN
- asp.net-core - 无法从控制器操作刷新页面(ASP.NET MVC 核心)
- reactjs - 子组件中的useSelector重新渲染父组件