首页 > 解决方案 > ResNet-101 FeatureMap 形状

问题描述

我对 CNN 很陌生,在学习它时遇到了很多麻烦。

我正在尝试使用 ResNet-101 提取 CNN 特征图,我希望获得 2048、14*14 的形状。为了获得特征图,我删除了 ResNet-101 模型的最后一层并调整了自适应平均池。所以我得到torch.Size([1, 2048, 1, 1])了输出的形状。

但我想得到的torch.Size([1, 2048, 14, 14])不是torch.Size([1, 2048, 1, 1]).

任何人都可以帮我得到结果吗?谢谢。

#load resnet101 model and remove the last layer
model = torch.hub.load('pytorch/vision:v0.5.0', 'resnet101', pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))


#extract feature map from an image and print the size of the feature map
from PIL import Image
import matplotlib.pylab as plt
from torchvision import transforms

filename = 'KM_0000000009.jpg'
input_image = Image.open(filename)

preprocess = transforms.Compose([
    transforms.Resize((244,244)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)

input_tensor = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

with torch.no_grad():
    output = model(input_tensor)

print(output.size()) #torch.Size([1, 2048, 1, 1])

标签: deep-learningobject-detectionresnetconv-neural-network

解决方案


你离你想要的只有一步之遥。

首先要做的事——你应该总是检查模块的源代码(位于这里的 ResNet)。它可能有一些功能操作(例如从torch.nn.functional模块),所以它可能不能直接转移到torch.nn.Seqential,幸运的是它是在 ResNet101 案例中。

其次,特征图取决于输入的大小,对于标准的类似 ImageNet 的图像大小([3, 224, 224],请注意您的图像大小不同)没有具有 shape 的层[2048, 14, 14],但是[2048, 7, 7]or [1024, 14, 14])。

第三,没有必要使用torch.hub ResNet101,因为它torchvision无论如何都使用引擎盖下的模型。

考虑到所有这些:

import torch
import torchvision

# load resnet101 model and remove the last layer
model = torchvision.models.resnet101()
model = torch.nn.Sequential(*(list(model.children())[:-3]))

# image-like
image = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    output = model(image)

print(output.size())  # torch.Size([1, 1024, 14, 14])

如果您想[2048, 7, 7]使用[:-2]而不是[:-3]. 此外,您可以在下面注意到特征图大小如何随图像形状变化:

model = torch.nn.Sequential(*(list(model.children())[:-2]))  
# Image twice as big -> twice as big height and width of features!
image = torch.randn(1, 3, 448, 448)

with torch.no_grad():
    output = model(image)

print(output.size())  # torch.Size([1, 2048, 14, 14])

推荐阅读