pytorch - 如何通过使用 pytorch 阅读我学习的权重的 .ckpt 文件来使用 resnet
问题描述
在 pytorch 中,如何编写加载我的 .ckpt 文件的代码,而不是
model = torchvision.models.resnet50(pretrained=True)
这是我在下面的尝试
model = torchvision.models.resnet50(pretrained=False)
PATH = "/content/drive/MyDrive/Colab Notebooks/mlearning2/multi_logs/resnet_2/version_0/checkpoints/epoch=1-step=2543.ckpt"
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
但它无法工作,并出现以下错误。
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1...
Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "callbacks", "optimizer_states", "lr_schedulers", "hparams_name", "hyper_parameters".
我该怎么做?
@Shai 我试着跑
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu'))['state_dict'])
解决方案
您保存的检查点不仅包含模型训练权重的快照,还包含有关训练状态的一些其他有用信息(例如,优化器的状态等)。
尝试仅选择已保存检查点的相关部分:
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu'))['state_dict'])
更新
根据您所做的修改和收到的新错误,保存的模型似乎是model.backbone = torchvision.models.resnet50()
.
您需要以model
与培训期间相同的方式实例化您的。
推荐阅读
- laravel - 如何获取从 Laravel 控制器中的复选框输入中获取的数组值?
- android - 如何在android Q中获取相邻单元格信息?
- c# - 如何映射两个不同的接口,以便一个接口值自动更改,另一个应该得到反映
- c++ - Oracle 即时客户端头文件丢失
- docker - 如何强制子域与 NGINX 完全匹配?
- c - 使用文本表进行类别线性转换的 Compu 方法
- r - R for 循环是数字
- plsqldeveloper - PL/SQL Developer 默认文件打开
- json - gunicorn 日志配置 access_log_format
- r - 使用 Plotly 绘制反应图