首页 > 解决方案 > 我尝试使用 pytorch children() 将 resnet 分为两部分,但它不起作用

问题描述

这是一个简单的例子。我试图将网络(Resnet50)分为两部分:headtail使用children. 从概念上讲,这应该有效,但事实并非如此。为什么?

import torch
import torch.nn as nn
from torchvision.models import resnet50

head = nn.Sequential(*list(resnet.children())[:-2])
tail = nn.Sequential(*list(resnet.children())[-2:])
x = torch.zeros(1, 3, 160, 160)

resnet(x).shape      # torch.Size([1, 1000])
head(x).shape        # torch.Size([1, 2048, 5, 5])
tail(head(x)).shape  # Error: RuntimeError: size mismatch, m1: [2048 x 1], m2: [2048 x 1000] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136

对于信息,尾巴不过是

Sequential(
  (0): AdaptiveAvgPool2d(output_size=(1, 1))
  (1): Linear(in_features=2048, out_features=1000, bias=True)
)

所以我实际上知道,如果我能做到这一点。但是,为什么重塑功能 ( view) 不在孩子们身上呢?

pool =resnet._modules['avgpool']
fc = resnet._modules['fc']
fc(pool(head(x)).view(1, -1))

标签: pytorch

解决方案


您要做的是将特征提取器与分类器分开。

  • 我应该立即指出的是,Resnet不是一个顺序模型(顾名思义 -残差网络- 它作为残差)!

    因此,将其编译为 ann.Sequential将不准确。.children()模型定义、排序显示的层与该模型功能的实际底层实现之间存在差异forward


  • 您使用的展平并未在所有模型view(1, -1)中注册为图层。torchvision.models.resnet*相反,它在定义中的这一行forward执行:

    x = torch.flatten(x, 1)
    

    他们可以将其注册为 as 中的一个层,以在__init__as实现中self.flatten = nn.Flatten()使用。forwardx = self.flatten(x)

    即便如此,与(参见第一点)fc(pool(head(x)).view(1, -1))完全不同。resnet(x)


推荐阅读