How to use ResNet with the learned weights from my .ckpt file in PyTorch

Problem description

In PyTorch, how do I write code that loads my own .ckpt file instead of

model = torchvision.models.resnet50(pretrained=True)

Here is my attempt:

model = torchvision.models.resnet50(pretrained=False)

PATH = "/content/drive/MyDrive/Colab Notebooks/mlearning2/multi_logs/resnet_2/version_0/checkpoints/epoch=1-step=2543.ckpt"

model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))

But it does not work and raises the following error:

RuntimeError: Error(s) in loading state_dict for ResNet:
    Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1...
    Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "callbacks", "optimizer_states", "lr_schedulers", "hparams_name", "hyper_parameters". 

How should I do this?

@Shai I tried running

model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu'))['state_dict'])

but it raised a different error, which was posted as a screenshot.

Tags: pytorch, conv-neural-network, load, visualization, resnet

Solution


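The checkpoint you saved contains not only a snapshot of the model's trained weights but also other information about the training state (for example, the optimizer state, the epoch, and the global step). That is why the error lists "epoch", "state_dict", "optimizer_states" and so on as unexpected keys.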

Try selecting only the relevant part of the saved checkpoint, i.e. the "state_dict" entry:

model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu'))['state_dict'])

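Update
Based on your edit and the new error you received, the saved model appears to have been defined as model.backbone = torchvision.models.resnet50(), so every parameter name in the checkpoint's state_dict starts with "backbone." and no longer matches a bare ResNet.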

You need to instantiate your model the same way it was instantiated during training.

