python - 用于多类分类任务的多层感知器
问题描述
假设我有一个 MLP,它使用 ReLU 作为激活函数和CrossEntropyLoss
损失函数来对具有 3 个特征的样本进行分类,这些特征属于 10 个类之一:我将如何实现它?目标值以 0 到 9 的数字给出。使用时CrossEntropyLoss
,目标值必须是简单的数字,而不是一个热向量。但是当试图将 MLP 的结果转换为单个数字时,我得到一个索引错误。
MLP的标准实现:
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
self.softmax = torch.nn.Softmax()
def forward(self, x):
hidden = self.fc1(x)
relu = self.relu(hidden)
output = self.fc2(relu)
output = self.softmax(output)
return output
以及给我一个错误的执行:
mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
for epoch in range(epoch):
y_pred = mlp_model(x_train)
y_scalar = torch.argmax(y_pred, dim=1)
loss = criterion(y_scalar, y_train) <-------------- error
loss.backward()
mlp_model.eval()
y_pred = mlp_model(x_test)
y_scalar = torch.argmax(y_pred, dim=1)
test_loss = criterion(y_scalar, y_test)
print('Test loss after Training' , test_loss.item())
y_pred_list = y_pred.tolist()
y_test_list = y_test.tolist()
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test_list, y_pred_list)
错误:IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
y_scalar 和 y_train 的输出:
tensor([1, 3, 3, 3, 1, 1, 1, 3, 3, 1, 3, 1, 1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3,
1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3,
3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 1,
1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3,
3, 1, 3, 1, 3, 3, 3, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 3, 3, 1, 3, 3,
1, 3, 1, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 3, 1, 3, 1, 1, 3, 3, 1, 1, 1,
1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 1,
1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3,
3, 3, 3, 3, 3, 1, 3, 1, 1, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
3, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3,
1, 3, 1, 3, 1, 3, 3, 3, 1, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1,
1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3,
3, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 3,
1, 1, 1, 3, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3, 3, 1,
3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 1, 1, 3, 1,
3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 3, 1,
1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 1, 1, 3, 3, 1,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 1, 3, 1,
3, 1, 3, 1, 1, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 1, 1, 3,
1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 1, 3, 3,
1, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1,
3, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3,
1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 1, 3, 3, 3, 3,
3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 1,
3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 3, 3, 1, 1,
3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3])
tensor([3., 4., 4., 0., 3., 2., 0., 3., 3., 2., 0., 0., 4., 3., 3., 3., 2., 3.,
1., 3., 5., 3., 4., 6., 3., 3., 6., 3., 2., 4., 3., 6., 0., 4., 2., 0.,
1., 5., 4., 4., 3., 6., 6., 4., 3., 3., 2., 5., 3., 4., 5., 3., 0., 2.,
1., 4., 6., 3., 2., 2., 0., 0., 0., 4., 2., 0., 4., 5., 2., 6., 5., 2.,
2., 2., 0., 4., 5., 6., 4., 0., 0., 0., 4., 2., 4., 1., 4., 6., 0., 4.,
2., 4., 6., 6., 0., 0., 6., 5., 0., 6., 0., 2., 1., 1., 1., 2., 6., 5.,
6., 1., 2., 2., 1., 5., 5., 5., 6., 5., 6., 5., 5., 1., 6., 6., 1., 5.,
1., 6., 5., 5., 5., 1., 5., 1., 1., 1., 1., 1., 1., 1., 4., 3., 0., 3.,
6., 6., 0., 3., 4., 0., 3., 4., 4., 1., 2., 2., 2., 3., 3., 3., 3., 0.,
4., 5., 0., 3., 4., 3., 3., 3., 2., 3., 3., 2., 2., 6., 1., 4., 3., 3.,
3., 6., 3., 3., 3., 3., 0., 4., 2., 2., 6., 5., 3., 5., 4., 0., 4., 3.,
4., 4., 3., 3., 2., 4., 0., 3., 2., 3., 3., 4., 4., 0., 3., 6., 0., 3.,
3., 4., 3., 3., 5., 2., 3., 2., 4., 1., 3., 2., 2., 3., 3., 3., 3., 5.,
1., 3., 1., 3., 5., 0., 3., 5., 0., 4., 2., 4., 2., 4., 4., 5., 4., 3.,
5., 3., 3., 4., 3., 0., 4., 5., 0., 3., 6., 2., 5., 5., 5., 3., 2., 3.,
0., 4., 5., 3., 0., 4., 0., 3., 3., 0., 0., 3., 5., 4., 4., 3., 4., 3.,
3., 2., 2., 3., 0., 3., 1., 3., 2., 3., 3., 4., 5., 2., 1., 1., 0., 0.,
1., 6., 1., 3., 3., 3., 2., 3., 3., 0., 3., 4., 1., 3., 4., 3., 2., 0.,
0., 4., 2., 3., 2., 1., 4., 6., 3., 2., 0., 3., 3., 2., 3., 4., 4., 2.,
1., 3., 5., 3., 2., 0., 4., 5., 1., 3., 3., 2., 0., 2., 4., 2., 2., 2.,
5., 4., 4., 2., 2., 0., 3., 2., 4., 4., 5., 5., 1., 0., 3., 4., 5., 3.,
4., 5., 3., 4., 3., 3., 1., 4., 3., 3., 5., 2., 3., 2., 5., 5., 4., 3.,
3., 3., 3., 1., 5., 3., 3., 2., 6., 0., 1., 3., 0., 1., 5., 3., 6., 3.,
6., 0., 3., 3., 3., 5., 4., 3., 4., 0., 5., 2., 1., 2., 4., 4., 4., 4.,
3., 3., 0., 4., 3., 0., 5., 2., 0., 5., 4., 4., 4., 3., 0., 6., 5., 2.,
4., 5., 1., 3., 5., 3., 0., 3., 5., 1., 1., 0., 3., 4., 2., 6., 2., 0.,
5., 3., 4., 6., 5., 3., 5., 0., 1., 3., 0., 5., 2., 2., 3., 5., 1., 0.,
3., 1., 4., 2., 5., 6., 4., 2., 2., 6., 0., 0., 4., 6., 3., 2., 0., 3.,
6., 1., 6., 3., 1., 3., 3., 3., 3., 2., 5., 4., 5., 5., 3., 1., 3., 3.,
4., 4., 2., 0., 2., 0., 5., 4., 0., 0., 3., 2., 2., 2., 2., 6., 4., 6.,
5., 5., 1., 0., 0., 4., 3., 3., 1., 3., 6., 6., 2., 3., 3., 3., 1., 2.,
2., 5., 4., 3., 2., 1., 2., 2., 3., 2., 3., 2., 3., 3., 0., 5., 3., 3.,
3., 4., 5., 3., 2., 1., 4., 4., 4., 4., 0., 5., 4., 1., 3., 0., 3., 4.,
6., 3., 6., 3., 3., 3., 6., 3., 4., 3., 6., 3., 0., 3., 1., 2., 5., 6.,
5., 2., 0., 2., 2., 3., 3., 0., 3., 5., 3., 4., 0., 3., 2., 4., 5., 2.,
3., 2., 2., 3., 5., 2., 0., 3., 4., 3.])```
解决方案
正如评论中提到的那样,模型内部不需要softmax,因为nn.CrossEntropyLoss
它包含它。此外,损失的计算是在 argmax 之前完成的。还要注意模型的输入和输出的形状。请参考以下更新。
import torch
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MLP, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
#self.softmax = torch.nn.Softmax()
def forward(self, x):
hidden = self.fc1(x)
relu = self.relu(hidden)
output = self.fc2(relu)
#output = self.softmax(output)
return output
mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
x_train = torch.randn(100, 3) # random 100 inputs of shape (100, 3)
y_train = torch.randint(low=0, high=10, size=(100,)) # random 100 ground truths of shape (100,)
for epoch in range(epoch):
y_pred = mlp_model(x_train)
y_scalar = torch.argmax(y_pred, dim=1)
#loss = criterion(y_scalar, y_train)# <-------------- error
loss = criterion(y_pred, y_train) # loss calculated before argmax
loss.backward().....
推荐阅读
- scala - “EntityStreamException:实体流截断”是否仅仅意味着客户端超时?
- html - 您可以在不嵌入的情况下在您的网站上显示 youtube 视频吗?
- mysql - 如何测试mysql插入方法
- java - 检查java中map的每个键作为空检查
- spring-boot - Spring Data JPA:java.sql.SQLException:找不到列“id”
- dotnetnuke - 从 DNN 9.0.1 升级到 DNN 9.0.2 时的 UpgradeWizard.aspx 错误
- c# - 启动在另一个项目时Controller找不到视图
- python - 可以使用 gensim Doc2Vec 将新文档与训练模型进行比较吗?
- sql - 工作日聚合的星期函数
- ckeditor - ckeditor balloonpanel在滚动时没有保持附着在元素上