deep-learning - ResNet-101 FeatureMap 形状
问题描述
我对 CNN 很陌生,在学习它时遇到了很多麻烦。
我正在尝试使用 ResNet-101 提取 CNN 特征图,我希望获得 2048、14*14 的形状。为了获得特征图,我删除了 ResNet-101 模型的最后一层并调整了自适应平均池。所以我得到torch.Size([1, 2048, 1, 1])
了输出的形状。
但我想得到的torch.Size([1, 2048, 14, 14])
不是torch.Size([1, 2048, 1, 1])
.
任何人都可以帮我得到结果吗?谢谢。
#load resnet101 model and remove the last layer
model = torch.hub.load('pytorch/vision:v0.5.0', 'resnet101', pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
#extract feature map from an image and print the size of the feature map
from PIL import Image
import matplotlib.pylab as plt
from torchvision import transforms
filename = 'KM_0000000009.jpg'
input_image = Image.open(filename)
preprocess = transforms.Compose([
transforms.Resize((244,244)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_tensor = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
with torch.no_grad():
output = model(input_tensor)
print(output.size()) #torch.Size([1, 2048, 1, 1])
解决方案
你离你想要的只有一步之遥。
首先要做的事——你应该总是检查模块的源代码(位于这里的 ResNet)。它可能有一些功能操作(例如从torch.nn.functional
模块),所以它可能不能直接转移到torch.nn.Seqential
,幸运的是它是在 ResNet101 案例中。
其次,特征图取决于输入的大小,对于标准的类似 ImageNet 的图像大小([3, 224, 224]
,请注意您的图像大小不同)没有具有 shape 的层[2048, 14, 14]
,但是[2048, 7, 7]
or [1024, 14, 14]
)。
第三,没有必要使用torch.hub
ResNet101,因为它torchvision
无论如何都使用引擎盖下的模型。
考虑到所有这些:
import torch
import torchvision
# load resnet101 model and remove the last layer
model = torchvision.models.resnet101()
model = torch.nn.Sequential(*(list(model.children())[:-3]))
# image-like
image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
output = model(image)
print(output.size()) # torch.Size([1, 1024, 14, 14])
如果您想[2048, 7, 7]
使用[:-2]
而不是[:-3]
. 此外,您可以在下面注意到特征图大小如何随图像形状变化:
model = torch.nn.Sequential(*(list(model.children())[:-2]))
# Image twice as big -> twice as big height and width of features!
image = torch.randn(1, 3, 448, 448)
with torch.no_grad():
output = model(image)
print(output.size()) # torch.Size([1, 2048, 14, 14])
推荐阅读
- html - 我怎样才能在这里使用另一个 css 库颜色而不是 Material Design?
- c++ - 从模板化函数中调用模板化类的模板化成员,带有显式参数
- visual-c++ - 我正在尝试在控制台中间打印我的正方形
- reactjs - 模块构建失败:TypeError:无法读取未定义的属性“babel”
- arrays - 在 Unity 中为按钮数组中的每个按钮分配整数值
- c++ - C++ 在循环中生成新随机数的问题/出现提示时
- visual-studio-code - Vscode 中的 dart 不支持全局评估
- c++ - 覆盖错误和更宽松的抛出说明符错误
- c# - HttpClient 非常慢 - Xamarin Forms
- jekyll - 列出两个类别中的所有帖子(Jekyll / Liquid)