python - CNN特征提取
问题描述
class ResNet(nn.Module):
def __init__(self, output_features, fine_tuning=False):
super(ResNet, self).__init__()
self.resnet152 = tv.models.resnet152(pretrained=True)
#freezing the feature extraction layers
for param in self.resnet152.parameters():
param.requires_grad = fine_tuning
#self.features = self.resnet152.features
self.num_fts = 512
self.output_features = output_features
# Linear layer goes from 512 to 1024
self.classifier = nn.Linear(self.num_fts, self.output_features)
nn.init.xavier_uniform_(self.classifier.weight)
self.tanh = nn.Tanh()
def forward(self, x):
h = self.resnet152(x)
print('h: ',h.shape)
return h
image_model_resnet152=ResNet(output_features=10).to(device)
image_model_resnet152
在这里,打印后image_model_resnet152
,我得到:
(avgpool): Linear(in_features=2048)
在这里,和 和有什么区别(classifier): Linear(in_features=512)
?
我正在实现一个图像字幕模型,那么in_features
我应该为图像取哪个?
解决方案
ResNet 不像 VGG 那样简单:它不是一个顺序模型,即forward
在 的定义中存在一些特定于模型的逻辑torchvision.models.resnet152
,例如,CNN 和分类器之间的特征扁平化。你可以看看它的源代码。
在这种情况下,最简单的做法是在 CNN: 的最后一层添加一个钩子,并将该layer4
层的结果记录在外部dict中。这是用register_forward_hook
.
定义钩子:
out = {}
def result(module, input, output):
out['layer4'] = output
将钩子连接到子模块上resnet.layer4
:
>>> x = torch.rand(1,3,224,224)
>>> resnet = torchvision.models.resnet152()
>>> resnet.layer4.register_forward_hook(result)
推断后,您将可以访问以下结果out
:
>>> resnet(x)
>>> out['layer4']
(1, 2048, 7, 7)
您可以查看我的另一个关于更深入地使用前钩的答案。
一个可能的实现是:
class NN(nn.Module):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet152()
self.resnet.layer4.register_forward_hook(result)
self.out = {}
@staticmethod
def result(module, input, output):
out['layer4'] = output
def forward(self, x):
x = self.resnet(x)
return out['layer4']
然后,您可以为您的自定义分类器定义附加层并在内部调用它们forward
。
推荐阅读
- docker - 多对接开发环境的最佳实践
- python - 将 3d 数组存储在 pandas 数据框列中
- aws-secrets-manager - 无法使用 Secrets Manager 密钥注册 AWS Batch 作业定义
- reactjs - React 功能组件中按钮 onClick 上的回调函数
- javascript - javascript场景中的“this”关键字是什么意思?
- python - 如何删除大于五的奇数,如何将所有偶数减半?
- reactjs - Typescript 在使用 create react app 进行编译时不会出错
- html - Mat-option appears at the bottom of the page instead of appearing under mat-select
- reactjs - 使用 SWR 单击时向 API 路由发出请求时出现过多的重新渲染错误
- react-native - React Native 无法识别导入