首页 > 解决方案 > 获取 PyTorch 中特定索引处的值

问题描述

我有一个大小为 5 的地面实况标签数组。

y=tensor([958,  85, 244, 182, 294])

我有形状分数数组的输出:[5,1000]

scores   =  tensor([[ 1.0406,  1.1808,  4.4227,  ...,  4.6864,  8.0145,  5.2128],
        [ 6.9101,  4.6083,  6.9259,  ...,  9.7415,  9.6305,  9.3974],
        [ 7.6097,  4.0396,  4.4560,  ...,  3.4892, 11.6411, 2],
        [ 1.0693,  4.6295,  5.3638,  ..., 10.9041, 10.8380,  9.2077],
        [ 1.7085,  1.4938,  8.6876,  ..., 15.1423,  9.6055,  9.8920]],
       grad_fn=<ViewBackward>)

我想要基于 y 的相应索引的分数数组中的值。所以对于 y[0],也就是 958,我想要从分数 [1] 中获得相应的值,位置 958。

我可以使用一些直接的 Pytorch 功能吗?

标签: deep-learningpytorch

解决方案


是的,你可以通过使用你的y数组作为索引来做到这一点:

scores[torch.arange(5), y]

推荐阅读