python - 如何从索引向量 B 中计算从向量 A 中选择的值的平均值?
问题描述
我有一个值向量,例如:
import torch
v = torch.rand(6)
tensor([0.0811, 0.9658, 0.1901, 0.0872, 0.8895, 0.9647])
和一个index
从 中选择值v
:
index = torch.tensor([0, 1, 0, 2, 0, 2])
我想生成一个向量,该向量将计算按索引分组mean
的平均值。v
index
在这个例子中:
mean = torch.tensor([(0.0811 + 0.1901 + 0.8895) / 3, 0.9658, (0.0872 + 0.9647) / 2, 0, 0, 0])
tensor([0.3869, 0.9658, 0.5260, 0.0000, 0.0000, 0.0000])
解决方案
torch.bincount
结合使用和 Tensor.index_add()的一种可能解决方案:
v = torch.tensor([0.0811, 0.9658, 0.1901, 0.0872, 0.8895, 0.9647])
index = torch.tensor([0, 1, 0, 2, 0, 2])
bincount() 获取每个索引使用的总数index
:
bincount = torch.bincount(index, minlength=6)
# --> tensor([3, 1, 2, 0, 0, 0])
index_add()v
按以下给出的顺序添加 from index
:
numerator = torch.zeros(6)
numerator = numerator.index_add(0, index, v)
# --> tensor([1.1607, 0.9658, 1.0520, 0.0000, 0.0000, 0.0000])
将 bincount 中的零替换为 1.0 以防止除以 0 并将 int 转换为 float:
div = bincount.float()
div[bincount == 0] = 1.0
# --> tensor([3., 1., 2., 1., 1., 1.])
mean = num/div
# --> tensor([0.3869, 0.9658, 0.5260, 0.0000, 0.0000, 0.0000])
推荐阅读
- php - php数组迭代视觉上父>子
- android - android Polyline会在一段时间后滞后地图
- mysql - 我在将 mysql 数据库(phpmyadmin)连接到 node.js 时收到错误
- csv - Spark:正在写入的加载文件
- php - Symfony - 控制器编辑动作
- xamarin.forms - IContainerRegistryExtensions 如何将实例注册为单例
- angular - return custom routes with angular and express
- java - how to switch from localhost to the internet in jsf
- android - Android 版本 >= 6.0 上不存在 String#value 字段
- elasticsearch - 麋鹿栈 GPG13