python-3.x - PyTorch 分类器输出什么?
问题描述
所以我是深度学习的新手,开始学习 PyTorch。我创建了一个具有以下结构的分类器模型。
class model(nn.Module):
def __init__(self):
super(model, self).__init__()
resnet = models.resnet34(pretrained=True)
layers = list(resnet.children())[:8]
self.features1 = nn.Sequential(*layers[:6])
self.features2 = nn.Sequential(*layers[6:])
self.classifier = nn.Sequential(nn.BatchNorm1d(512), nn.Linear(512, 3))
def forward(self, x):
x = self.features1(x)
x = self.features2(x)
x = F.relu(x)
x = nn.AdaptiveAvgPool2d((1,1))(x)
x = x.view(x.shape[0], -1)
return self.classifier(x)
所以基本上我想在三件事{0,1,2}之间进行分类。在评估时,我传递了图像,它返回了一个具有三个值的张量,如下所示
(tensor([[-0.1526, 1.3511, -1.0384]], device='cuda:0', grad_fn=<AddmmBackward>)
所以我的问题是这三个数字是什么?他们是概率吗?
PS如果我问得太傻了,请原谅我。
解决方案
所以在训练之后你想要做的是应用softmax
到输出张量来提取每个类的概率,然后你选择最大值(最高概率)。
在你的情况下:
prob = torch.nn.functional.softmax(model(x), dim=1)
_, pred_class = torch.max(prob, dim=1)
推荐阅读
- clickhouse - ClickHouse 分布式表和 insert_quorum
- python - 如何解决这个问题:pyreportjasper: ImportError: DLL load failed while importing jpy: The specified module could not be found
- java - 某些移动设备上的身份验证问题
- c++ - 如何修复 curl 链接器错误
- javascript - 格式化和解析输入值
- reactjs - TypeScript 错误类型“{}”缺少类型中的以下属性
- laravel - 在 laravel 中将当前类别设置为活动
- python - 为什么我的函数在 python 中没有打印
- reactjs - useFormState 必须在 a 内部使用
- matlab - Bootstrp在MATLAB中使用矩阵而不是向量?