How to extract DLA34 features for CenterNet?

Problem description

  1. CenterNet paper: https://arxiv.org/abs/1904.07850
  2. Deep Layer Aggregation paper: https://arxiv.org/abs/1707.06484

I am trying to use DLA34 as the backbone for CenterNet via this repo: https://github.com/xingyizhou/CenterNet/blob/master/src/lib/models/networks/pose_dla_dcn.py

I can load EfficientNet features for CenterNet like this:

from efficientnet_pytorch import EfficientNet

base_model = EfficientNet.from_pretrained('efficientnet-b1')
# Crop the central horizontal strip of the batch (x and IMG_WIDTH defined elsewhere)
x_center = x[:, :, :, IMG_WIDTH // 8: -IMG_WIDTH // 8]
feats = base_model.extract_features(x_center)

But Deep Layer Aggregation (DLA34) has no extract_features() method. I am new to object detection: how can I extract features from DLA34, and from other networks such as DenseNet, for use with CenterNet?

Tags: python, deep-learning, conv-neural-network, object-detection, image-segmentation

Solution
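
The key point is that the dla34 implementation in the CenterNet repo (pose_dla_dcn.py) does not need an extract_features() method: its forward pass already returns a list of per-level feature maps, so indexing that list plays the same role. A minimal sketch of this, assuming dla34 is imported from that file (the import path below is an assumption; adjust it to wherever pose_dla_dcn.py sits in your checkout):

import torch
# Assumed import path; in the CenterNet repo the file lives under
# src/lib/models/networks/
from pose_dla_dcn import dla34

base_model = dla34(pretrained=True)
x = torch.randn(1, 3, 256, 256)      # dummy image batch

levels = base_model(x)               # list of 6 feature maps, one per DLA level
feats8 = levels[3]                   # (1, 128, 32, 32), stride 8
feats16 = levels[4]                  # (1, 256, 16, 16), stride 16
feats32 = levels[5]                  # (1, 512, 8, 8),  stride 32

The class below uses exactly this: it takes levels[5] as the stride-32 feature map and projects it with a 1x1 lateral convolution.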


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler

class CentDla(nn.Module):
    '''Mixture of previous classes'''
    def __init__(self, n_classes):
        super(CentDla, self).__init__()
        # dla34 from the CenterNet repo (pose_dla_dcn.py); a similar model is
        # also available in pytorchcv:
        # https://github.com/osmr/imgclsmob/blob/master/pytorch/pytorchcv/models/dla.py
        self.base_model = dla34(pretrained=True)

        # Lateral layers project the DLA level outputs to a common feature size
        self.lat8 = nn.Conv2d(128, 256, 1)   # level 3 output, stride 8
        self.lat16 = nn.Conv2d(256, 256, 1)  # level 4 output, stride 16
        self.lat32 = nn.Conv2d(512, 256, 1)  # level 5 output, stride 32
        self.bn8 = nn.BatchNorm2d(256)
        self.bn16 = nn.BatchNorm2d(256)
        self.bn32 = nn.BatchNorm2d(256)

        # Encoder over the mesh-augmented input (3 image + 2 mesh channels = 5)
        self.conv0 = double_conv(5, 64)
        self.conv1 = double_conv(64, 128)
        self.conv2 = double_conv(128, 512)
        self.conv3 = double_conv(512, 1024)

        self.mp = nn.MaxPool2d(2)

        # 1282 = (256 lateral + 2 mesh) + 1024 encoder channels after concatenation
        self.up1 = up(1282, 512)
        self.up2 = up(512 + 512, 256)
        self.outc = nn.Conv2d(256, n_classes, 1)

    def forward(self, x):
        batch_size = x.shape[0]
        mesh1 = get_mesh(batch_size, x.shape[2], x.shape[3])
        x0 = torch.cat([x, mesh1], 1)
        x1 = self.mp(self.conv0(x0))
        x2 = self.mp(self.conv1(x1))
        x3 = self.mp(self.conv2(x2))
        x4 = self.mp(self.conv3(x3))

        # Run the frontend network: dla34 returns a list of level outputs,
        # and index 5 is the deepest map (512 channels, stride 32)
        feats32 = self.base_model(x)[5]
        lat32 = F.relu(self.bn32(self.lat32(feats32)))

        # Add positional info
        mesh2 = get_mesh(batch_size, lat32.shape[2], lat32.shape[3])
        feats = torch.cat([lat32, mesh2], 1)

        x = self.up1(feats, x4)
        x = self.up2(x, x3)
        x = self.outc(x)
        return x
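
# The class above relies on helpers from the original Kaggle kernel that are
# not shown in this answer. The definitions below are a plausible sketch under
# that assumption, not the kernel's exact code: double_conv/up are standard
# UNet blocks and get_mesh builds a 2-channel normalized coordinate grid
# (CoordConv-style positional encoding).
def double_conv(in_ch, out_ch):
    # Two 3x3 conv + BN + ReLU stages
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

class up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(up, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        # Upsample the deeper map, concatenate with the skip connection
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

def get_mesh(batch_size, h, w):
    # Two channels holding normalized x/y pixel coordinates;
    # uses the global `device` defined just below
    xs = torch.linspace(0, 1, w).view(1, 1, 1, w).expand(batch_size, 1, h, w)
    ys = torch.linspace(0, 1, h).view(1, 1, h, 1).expand(batch_size, 1, h, w)
    return torch.cat([xs, ys], 1).to(device)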

# Get the GPU if there is one, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

n_epochs = 20
n_classes = 8
model = CentDla(n_classes).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
#optimizer = RAdam(model.parameters(), lr=0.001)
# Decay the LR by 10x roughly every third of training, assuming the scheduler
# is stepped once per batch (train_loader is defined elsewhere)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=max(n_epochs, 10) * len(train_loader) // 3, gamma=0.1)
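
The question also mentions DenseNet. torchvision's DenseNet models likewise have no extract_features() helper, but they expose the convolutional trunk as a features submodule, so you can call it directly. A minimal sketch using densenet121 (chosen here purely for illustration):

import torch
import torchvision

densenet = torchvision.models.densenet121(pretrained=True)
x = torch.randn(1, 3, 256, 256)   # dummy image batch
feats = densenet.features(x)      # (1, 1024, 8, 8): last conv feature map before the classifier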
