首页 > 解决方案 > 如何使用 Pytorch 创建最后一层正确编写的自定义 EfficientNet

问题描述

例如,我有一个分类问题来预测 8 个类,我EfficientNetB3在 pytorch from here中使用。但是,我对我的自定义类是否正确编写感到困惑。我想我想剥离预训练模型的最后一层以适应 8 个输出,对吗?我做对了吗?因为当我y_preds = model(images)在 my中打印时DataLoader,它似乎给了我1536预测。这是预期的行为吗?

!pip install geffnet 
import geffnet

class EfficientNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = geffnet.create_model(config.effnet, pretrained=True)
        n_features = self.model.classifier.in_features
        # does the name fc matter?
        self.fc = nn.Linear(n_features, config.num_classes)
        self.model.classifier = nn.Identity()
        
    def extract(self, x):
        x = self.model(x)
        return x

    def forward(self, x):
        x = self.extract(x).squeeze(-1).squeeze(-1)
        return x
    
model = EfficientNet(config=config)
if torch.cuda.is_available():
    model.cuda()

打印示例代码y_pred

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for step, (images, labels) in enumerate(sample_loader):
    images = images.to(device)
    labels = labels.to(device)
    batch_size = images.shape[0]        
    y_preds = model(images)
    print('The predictions of the 4 images is as follows\n', y_preds)
    break

标签: machine-learningdeep-learningpytorchclassificationconv-neural-network

解决方案


你甚至没有self.fc在前向传球中使用。

要么将其介绍为:

def forward(self, x):
    ....
    x = extract(x)...
    x = fc(x)
    return x

或者您可以简单地替换名为分类器的层(这样您就不需要身份层):

self.model.classifier = nn.Linear(n_features, config.num_classes)

此外,这里config.num_classes应该是 8。


推荐阅读