首页 > 解决方案 > 获取连续最小比较的索引

问题描述

说如果我有张量

values = torch.tensor([5., 4., 8., 3.])

我想对每两个连续的值取最小值,意思是

min(5., 4.) = 4.
min(8., 3.) = 3.

有没有一种矢量化的方式来做它并且仍然获得最小值的相对索引?意思是我想要的输出是:

min_index = [1, 1]
#min_index[0] == 1 as 4. is the minimum of (5., 4.) and is in index 1 of (5., 4.)
#min_index[1] == 1 as 3. is the minimum of (8., 3.) and is in index 1 of (8., 3.) 

标签: pythonnumpypytorch

解决方案


我认为重塑你的张量会使它更容易。之后torch.min自动返回最小值和索引。

import torch

values = torch.tensor([5., 4., 8., 3.])
values_reshaped = values.reshape(-1,2) # works for any length
minimums, index = torch.min(values_reshaped, axis = -1)
print(minimums) # tensor of the minimum values
print(index) # tensor of indexes

推荐阅读