python - The difference between loading model parameters with load_state_dict and nn.Parameter in PyTorch
Problem description
When I assign part of a pre-trained model's parameters to a module defined in a new PyTorch model, I get two different outputs depending on which of two methods I use.
The network is defined as follows:
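import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.resnet = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
        # Drop the final fully connected layer and keep the feature extractor
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])
        self.freeze_model(self.resnet)
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 3),
        )

    def freeze_model(self, model):
        # Not shown in the original post; assumed here to disable gradients only
        for param in model.parameters():
            param.requires_grad = False

    def forward(self, x):
        out = self.resnet(x)
        out = out.flatten(start_dim=1)
        out = self.classifier(out)
        return out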
What I want is to assign the pre-trained parameters to the classifier in the net module. I tried two different ways of doing this.
# First way
net.load_state_dict(torch.load('model_CNN_pretrained.ptl'))
# Second way
params = torch.load('model_CNN_pretrained.ptl')
net.classifier[1].weight = nn.Parameter(params['classifier.1.weight'], requires_grad=False)
net.classifier[1].bias = nn.Parameter(params['classifier.1.bias'], requires_grad=False)
net.classifier[3].weight = nn.Parameter(params['classifier.3.weight'], requires_grad=False)
net.classifier[3].bias = nn.Parameter(params['classifier.3.bias'], requires_grad=False)
In both cases the classifier parameters were assigned correctly, yet the two nets give different outputs for the same input data. The first method works correctly, but the second does not. Could someone point out the difference between these two methods?
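Solution
Finally, I figured out where the problem is.
During pre-training, the buffers of the BatchNorm2d layers in the ResNet18 model (running_mean and running_var) change even if requires_grad is set to False on the parameters. These buffers are updated from the input data during forward passes under model.train() and stay fixed after model.eval(). load_state_dict restores the saved buffers together with the parameters, whereas the second way copies only the classifier weights and biases and leaves the BatchNorm buffers of the ResNet at their original pre-trained values, so the two nets produce different outputs.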
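The behaviour described above can be reproduced without downloading ResNet18. The following is a minimal sketch, assuming a tiny Conv2d + BatchNorm2d stack as a stand-in for the frozen backbone (the layer sizes and variable names are illustrative, not from the original post). It checks whether the running statistics move in train mode despite requires_grad=False, and whether they stay put in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in for a frozen backbone; the layer sizes are illustrative only.
backbone = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))
for p in backbone.parameters():
    p.requires_grad = False  # freezes weights and biases, but not buffers

bn = backbone[1]
x = torch.randn(8, 3, 16, 16)

# Forward pass in train mode: BatchNorm updates running_mean / running_var.
before = bn.running_mean.clone()
backbone.train()
with torch.no_grad():
    backbone(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# Forward pass in eval mode: the buffers are read but not updated.
after_train = bn.running_mean.clone()
backbone.eval()
with torch.no_grad():
    backbone(x)
changed_in_eval = not torch.equal(after_train, bn.running_mean)

# Buffers are part of state_dict() but are not returned by parameters().
buffer_names = [name for name, _ in backbone.named_buffers()]
param_names = [name for name, _ in backbone.named_parameters()]

print("changed in train mode:", changed_in_train)
print("changed in eval mode:", changed_in_eval)
print("buffers:", buffer_names)
```

Because the buffers are stored in state_dict() but never appear in parameters(), copying individual weights with nn.Parameter cannot transfer them.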
There is a link about how to freeze the BN layers.
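One common way to freeze BN layers is to keep them in eval mode even while the rest of the model trains. The sketch below is an assumption about what such a recipe looks like, not the content of the linked post: freeze_bn and FrozenBackboneNet are hypothetical names, and the small backbone again stands in for the ResNet18 feature extractor.

```python
import torch
import torch.nn as nn

def freeze_bn(module: nn.Module) -> None:
    """Hypothetical helper: put every BatchNorm layer under `module` into eval mode."""
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()

class FrozenBackboneNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative stand-in for the pre-trained feature extractor.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.AdaptiveAvgPool2d(1),
        )
        for p in self.backbone.parameters():
            p.requires_grad = False
        self.classifier = nn.Linear(4, 3)

    def train(self, mode: bool = True):
        # train() switches every submodule, so BN must be re-frozen afterwards.
        super().train(mode)
        freeze_bn(self.backbone)
        return self

    def forward(self, x):
        out = self.backbone(x).flatten(start_dim=1)
        return self.classifier(out)

net = FrozenBackboneNet()
net.train()

bn = net.backbone[1]
before = bn.running_mean.clone()
net(torch.randn(8, 3, 16, 16))
buffers_unchanged = torch.equal(before, bn.running_mean)
print("BN buffers unchanged after a training-mode forward pass:", buffers_unchanged)
```

Overriding train() matters because a later call to net.train() would otherwise switch the BN layers back to training mode and resume updating their running statistics.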