首页 > 解决方案 > 用于多类分类任务的多层感知器

问题描述

假设我有一个 MLP,它使用 ReLU 作为激活函数和CrossEntropyLoss损失函数来对具有 3 个特征的样本进行分类,这些特征属于 10 个类之一:我将如何实现它?目标值以 0 到 9 的数字给出。使用时CrossEntropyLoss,目标值必须是简单的数字,而不是一个热向量。但是当试图将 MLP 的结果转换为单个数字时,我得到一个索引错误。

MLP的标准实现:

class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__() 
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.output_size = output_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
        self.softmax = torch.nn.Softmax()
    
    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        output = self.softmax(output)
        return output

以及给我一个错误的执行:

mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
for epoch in range(epoch):
    y_pred = mlp_model(x_train)
    y_scalar = torch.argmax(y_pred, dim=1)

    loss = criterion(y_scalar, y_train) <-------------- error

    loss.backward()
mlp_model.eval()
y_pred = mlp_model(x_test)
y_scalar = torch.argmax(y_pred, dim=1)
test_loss = criterion(y_scalar, y_test) 
print('Test loss after Training' , test_loss.item())

y_pred_list = y_pred.tolist()
y_test_list = y_test.tolist()

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test_list, y_pred_list)

错误:IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

y_scalar 和 y_train 的输出:

tensor([1, 3, 3, 3, 1, 1, 1, 3, 3, 1, 3, 1, 1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3,
        1, 3, 1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3,
        3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 1,
        1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3,
        3, 1, 3, 1, 3, 3, 3, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 3, 3, 1, 3, 3,
        1, 3, 1, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 3, 1, 3, 1, 1, 3, 3, 1, 1, 1,
        1, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 1,
        1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3,
        3, 3, 3, 3, 3, 1, 3, 1, 1, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        3, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3,
        1, 3, 1, 3, 1, 3, 3, 3, 1, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 3, 1,
        1, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3,
        3, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 3,
        1, 1, 1, 3, 1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 1, 1, 3, 3, 3, 3, 1,
        3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 3, 1, 3, 1, 3, 1, 1, 1, 1, 1, 3, 1,
        3, 1, 1, 3, 3, 1, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 3, 3, 1, 3, 1, 3, 3, 1,
        1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 3, 3, 1, 1, 1, 3, 3, 1,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 3, 3, 1, 3, 1,
        3, 1, 3, 1, 1, 3, 3, 1, 3, 3, 1, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 1, 1, 3,
        1, 3, 3, 1, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 1, 3, 3,
        1, 3, 3, 3, 3, 1, 3, 1, 1, 1, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 1,
        3, 1, 3, 3, 1, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3,
        1, 1, 3, 3, 3, 3, 1, 1, 3, 3, 1, 1, 1, 3, 3, 3, 1, 3, 1, 1, 3, 3, 3, 3,
        3, 3, 3, 3, 1, 3, 3, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 3, 1, 1, 1, 3, 3, 1,
        3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 1, 3, 3, 1, 1,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3])
tensor([3., 4., 4., 0., 3., 2., 0., 3., 3., 2., 0., 0., 4., 3., 3., 3., 2., 3.,
        1., 3., 5., 3., 4., 6., 3., 3., 6., 3., 2., 4., 3., 6., 0., 4., 2., 0.,
        1., 5., 4., 4., 3., 6., 6., 4., 3., 3., 2., 5., 3., 4., 5., 3., 0., 2.,
        1., 4., 6., 3., 2., 2., 0., 0., 0., 4., 2., 0., 4., 5., 2., 6., 5., 2.,
        2., 2., 0., 4., 5., 6., 4., 0., 0., 0., 4., 2., 4., 1., 4., 6., 0., 4.,
        2., 4., 6., 6., 0., 0., 6., 5., 0., 6., 0., 2., 1., 1., 1., 2., 6., 5.,
        6., 1., 2., 2., 1., 5., 5., 5., 6., 5., 6., 5., 5., 1., 6., 6., 1., 5.,
        1., 6., 5., 5., 5., 1., 5., 1., 1., 1., 1., 1., 1., 1., 4., 3., 0., 3.,
        6., 6., 0., 3., 4., 0., 3., 4., 4., 1., 2., 2., 2., 3., 3., 3., 3., 0.,
        4., 5., 0., 3., 4., 3., 3., 3., 2., 3., 3., 2., 2., 6., 1., 4., 3., 3.,
        3., 6., 3., 3., 3., 3., 0., 4., 2., 2., 6., 5., 3., 5., 4., 0., 4., 3.,
        4., 4., 3., 3., 2., 4., 0., 3., 2., 3., 3., 4., 4., 0., 3., 6., 0., 3.,
        3., 4., 3., 3., 5., 2., 3., 2., 4., 1., 3., 2., 2., 3., 3., 3., 3., 5.,
        1., 3., 1., 3., 5., 0., 3., 5., 0., 4., 2., 4., 2., 4., 4., 5., 4., 3.,
        5., 3., 3., 4., 3., 0., 4., 5., 0., 3., 6., 2., 5., 5., 5., 3., 2., 3.,
        0., 4., 5., 3., 0., 4., 0., 3., 3., 0., 0., 3., 5., 4., 4., 3., 4., 3.,
        3., 2., 2., 3., 0., 3., 1., 3., 2., 3., 3., 4., 5., 2., 1., 1., 0., 0.,
        1., 6., 1., 3., 3., 3., 2., 3., 3., 0., 3., 4., 1., 3., 4., 3., 2., 0.,
        0., 4., 2., 3., 2., 1., 4., 6., 3., 2., 0., 3., 3., 2., 3., 4., 4., 2.,
        1., 3., 5., 3., 2., 0., 4., 5., 1., 3., 3., 2., 0., 2., 4., 2., 2., 2.,
        5., 4., 4., 2., 2., 0., 3., 2., 4., 4., 5., 5., 1., 0., 3., 4., 5., 3.,
        4., 5., 3., 4., 3., 3., 1., 4., 3., 3., 5., 2., 3., 2., 5., 5., 4., 3.,
        3., 3., 3., 1., 5., 3., 3., 2., 6., 0., 1., 3., 0., 1., 5., 3., 6., 3.,
        6., 0., 3., 3., 3., 5., 4., 3., 4., 0., 5., 2., 1., 2., 4., 4., 4., 4.,
        3., 3., 0., 4., 3., 0., 5., 2., 0., 5., 4., 4., 4., 3., 0., 6., 5., 2.,
        4., 5., 1., 3., 5., 3., 0., 3., 5., 1., 1., 0., 3., 4., 2., 6., 2., 0.,
        5., 3., 4., 6., 5., 3., 5., 0., 1., 3., 0., 5., 2., 2., 3., 5., 1., 0.,
        3., 1., 4., 2., 5., 6., 4., 2., 2., 6., 0., 0., 4., 6., 3., 2., 0., 3.,
        6., 1., 6., 3., 1., 3., 3., 3., 3., 2., 5., 4., 5., 5., 3., 1., 3., 3.,
        4., 4., 2., 0., 2., 0., 5., 4., 0., 0., 3., 2., 2., 2., 2., 6., 4., 6.,
        5., 5., 1., 0., 0., 4., 3., 3., 1., 3., 6., 6., 2., 3., 3., 3., 1., 2.,
        2., 5., 4., 3., 2., 1., 2., 2., 3., 2., 3., 2., 3., 3., 0., 5., 3., 3.,
        3., 4., 5., 3., 2., 1., 4., 4., 4., 4., 0., 5., 4., 1., 3., 0., 3., 4.,
        6., 3., 6., 3., 3., 3., 6., 3., 4., 3., 6., 3., 0., 3., 1., 2., 5., 6.,
        5., 2., 0., 2., 2., 3., 3., 0., 3., 5., 3., 4., 0., 3., 2., 4., 5., 2.,
        3., 2., 2., 3., 5., 2., 0., 3., 4., 3.])```

标签: pythonmachine-learningpytorchmulticlass-classification

解决方案


正如评论中提到的那样,模型内部不需要softmax,因为nn.CrossEntropyLoss它包含它。此外,损失的计算是在 argmax 之前完成的。还要注意模型的输入和输出的形状。请参考以下更新。

import torch
class MLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__() 
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.output_size = output_size
        self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(self.hidden_size, self.output_size)
        #self.softmax = torch.nn.Softmax()
    
    def forward(self, x):
        hidden = self.fc1(x)
        relu = self.relu(hidden)
        output = self.fc2(relu)
        #output = self.softmax(output)
        return output

mlp_model = MLP(3, 10, 10)
criterion = torch.nn.CrossEntropyLoss()
mlp_model.train()
epoch = 20
x_train = torch.randn(100, 3) # random 100 inputs of shape (100, 3)
y_train = torch.randint(low=0, high=10, size=(100,)) # random 100 ground truths of shape (100,)
for epoch in range(epoch):
    y_pred = mlp_model(x_train)
    y_scalar = torch.argmax(y_pred, dim=1)

    #loss = criterion(y_scalar, y_train)# <-------------- error
    loss = criterion(y_pred, y_train) # loss calculated before argmax

    loss.backward().....

推荐阅读