python - 如何保存模型的训练权重检查点并从 PyTorch 的最后一点继续训练?
问题描述
我正在尝试在一定数量的时期(epoch)后保存训练模型的检查点权重,然后使用 PyTorch 从最后一个检查点继续训练若干个时期。为了实现这一点,我编写了如下脚本
训练模型:
def create_model(encoder_name="resnet152", encoder_weights="imagenet",
                 in_channels=3, classes=2):
    """Build a segmentation-models-pytorch U-Net.

    Args:
        encoder_name: backbone encoder, e.g. "resnet152", "mobilenet_v2",
            "efficientnet-b7".
        encoder_weights: pre-trained weights for encoder initialization
            ("imagenet") or None for random init.
        in_channels: number of model input channels (1 for gray-scale
            images, 3 for RGB, etc.).
        classes: number of model output channels (number of classes in
            the dataset).

    Returns:
        An ``smp.Unet`` instance; the caller is responsible for moving it
        to the target device.
    """
    return smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
    )
# --- Training script: train from scratch, checkpointing every epoch ---
model = create_model()
model.to(device)

learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

epochs = 5
for epoch in range(epochs):
    print('Epoch: [{}/{}]'.format(epoch + 1, epochs))

    # ---- train set ----
    pbar = tqdm(train_loader)
    model.train()
    iou_logger = iouTracker()
    for batch in pbar:
        # load image and mask into device memory
        image = batch['image'].to(device)
        mask = batch['mask'].to(device)
        # pass images into model
        pred = model(image)
        # get loss
        loss = criteria(pred, mask)
        # update the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # compute and display progress
        iou_logger.update(pred, mask)
        mIoU = iou_logger.get_mean()
        pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

    # ---- development set ----
    pbar = tqdm(development_loader)
    model.eval()
    iou_logger = iouTracker()
    with torch.no_grad():
        for batch in pbar:
            image = batch['image'].to(device)
            mask = batch['mask'].to(device)
            pred = model(image)
            loss = criteria(pred, mask)
            iou_logger.update(pred, mask)
            mIoU = iou_logger.get_mean()
            pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

    # ---- save a checkpoint at the end of every epoch ----
    # Store loss.item() (a plain float) rather than the loss tensor:
    # pickling the tensor drags its autograd graph along with it.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, '/content/drive/MyDrive/checkpoint.pt')
由此,在训练 5 个时期之后,我可以将模型检查点保存为 checkpoint.pt 文件。
为了使用保存的检查点权重文件继续训练,我在下面编写了另一个脚本:
# --- Resume script: load the checkpoint ONCE, then continue training ---
# Loading inside the epoch loop (as in the original) resets the model and
# optimizer back to the same checkpoint on every iteration, so training
# never progresses past the saved state.
checkpoint = torch.load('/content/drive/MyDrive/checkpoint.pt',
                        map_location=device)
# NOTE(review): the "Missing key(s) ... module.encoder..." error in the
# traceback means `model` is wrapped in nn.DataParallel here while the
# checkpoint was saved from an unwrapped model. In that case load into the
# bare module instead: model.module.load_state_dict(...) — confirm how
# `model` is constructed before this script runs.
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epochs = 5
# Resume AFTER the last completed epoch instead of clobbering the loop
# variable with checkpoint['epoch'].
start_epoch = checkpoint['epoch'] + 1
for epoch in range(start_epoch, epochs):
    print('Epoch: [{}/{}]'.format(epoch + 1, epochs))

    # ---- train set ----
    pbar = tqdm(train_loader)
    model.train()
    iou_logger = iouTracker()
    for batch in pbar:
        # load image and mask into device memory
        image = batch['image'].to(device)
        mask = batch['mask'].to(device)
        # pass images into model
        pred = model(image)
        # get loss
        loss = criteria(pred, mask)
        # update the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # compute and display progress
        iou_logger.update(pred, mask)
        mIoU = iou_logger.get_mean()
        pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

    # ---- development set ----
    pbar = tqdm(development_loader)
    model.eval()
    iou_logger = iouTracker()
    with torch.no_grad():
        for batch in pbar:
            image = batch['image'].to(device)
            mask = batch['mask'].to(device)
            pred = model(image)
            loss = criteria(pred, mask)
            iou_logger.update(pred, mask)
            mIoU = iou_logger.get_mean()
            pbar.set_description('Loss: {0:1.4f} | mIoU {1:1.4f}'.format(loss.item(), mIoU))

    # ---- save a checkpoint at the end of every epoch ----
    # Write to the same Drive path the first script used; the original's
    # relative 'checkpoint.pt' would not persist to Drive.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, '/content/drive/MyDrive/checkpoint.pt')
这会引发错误:
RuntimeError Traceback (most recent call last)
<ipython-input-31-54f48c10531a> in <module>()
---> 14 model.load_state_dict(checkpoint['model_state_dict'])
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1222 if len(error_msgs) > 0:
1223 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1224 self.__class__.__name__, "\n\t".join(error_msgs)))
1225 return _IncompatibleKeys(missing_keys, unexpected_keys)
1226
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.encoder.conv1.weight", "module.encoder.bn1.weight", "module.encoder.bn1.bias", "module.encoder.bn1.running_mean", "module.encoder.bn1.running_var", "module.encoder.layer1.0.conv1.weight", "module.encoder.layer1.0.bn1.weight", "module.encoder.layer1.0.bn1.bias", "module.encoder.layer1.0.bn1.running_mean", "module.encoder.layer1.0.bn1.running_var", "module.encoder.layer1.0.conv2.weight", "module.encoder.layer1.0.bn2.weight", "module.encoder.layer1.0.bn2.bias", "module.encoder.layer1.0.bn2.running_mean", "module.encoder.layer1.0.bn2.running_var", "module.encoder.layer1.0.conv3.weight", "module.encoder.layer1.0.bn3.weight", "module.encoder.layer1.0.bn3.bias", "module.encoder.layer1.0.bn3.running_mean", "module.encoder.layer1.0.bn3.running_var", "module.encoder.layer1.0.downsample.0.weight", "module.encoder.layer1.0.downsample.1.weight", "module.encoder.layer1.0.downsample.1.bias", "module.encoder.layer1.0.downsample.1.running_mean", "module.encoder.layer1.0.downsample.1.running_var", "module.encoder.layer1.1.conv1.weight", "module.encoder.layer1.1.bn1.weight", "module.encoder.layer1.1.bn1.bias", "module.encoder.layer1.1.bn1.running_mean", "module.encoder.layer1.1.bn1.running_var", "module.encoder.layer1.1.conv2.weight", "module.encoder.layer1.1.bn2.weight", "module.encoder.layer1.1.bn2.bias", "module.encoder.layer1.1.bn2.running_mean", "module.encoder.layer1.1.bn2.running_var", "module.encoder.layer1.1.conv3.weight", "module.encoder.layer...
Unexpected key(s) in state_dict: "encoder.conv1.weight", "encoder.bn1.weight", "encoder.bn1.bias", "encoder.bn1.running_mean", "encoder.bn1.running_var", "encoder.bn1.num_batches_tracked", "encoder.layer1.0.conv1.weight", "encoder.layer1.0.bn1.weight", "encoder.layer1.0.bn1.bias", "encoder.layer1.0.bn1.running_mean", "encoder.layer1.0.bn1.running_var", "encoder.layer1.0.bn1.num_batches_tracked", "encoder.layer1.0.conv2.weight", "encoder.layer1.0.bn2.weight", "encoder.layer1.0.bn2.bias", "encoder.layer1.0.bn2.running_mean", "encoder.layer1.0.bn2.running_var", "encoder.layer1.0.bn2.num_batches_tracked", "encoder.layer1.1.conv1.weight", "encoder.layer1.1.bn1.weight", "encoder.layer1.1.bn1.bias", "encoder.layer1.1.bn1.running_mean", "encoder.layer1.1.bn1.running_var", "encoder.layer1.1.bn1.num_batches_tracked", "encoder.layer1.1.conv2.weight", "encoder.layer1.1.bn2.weight", "encoder.layer1.1.bn2.bias", "encoder.layer1.1.bn2.running_mean", "encoder.layer1.1.bn2.running_var", "encoder.layer1.1.bn2.num_batches_tracked", "encoder.layer1.2.conv1.weight", "encoder.layer1.2.bn1.weight", "encoder.layer1.2.bn1.bias", "encoder.layer1.2.bn1.running_mean", "encoder.layer1.2.bn1.running_var", "encoder.layer1.2.bn1.num_batches_tracked", "encoder.layer1.2.conv2.weight", "encoder.layer1.2.bn2.weight", "encoder.layer1.2.bn2.bias", "encoder.layer1.2.bn2.running_mean", "encoder.layer1.2.bn2.running_var", "encoder.layer1.2.bn2.num_batches_tracked", "encoder.layer2.0.conv1.weight", "encoder.layer...
我究竟做错了什么?我怎样才能解决这个问题?对此的任何帮助都会有所帮助。
解决方案
报错信息本身已经给出了原因:缺失的键都带有 `module.` 前缀(例如 "module.encoder.conv1.weight"),说明加载时的模型被 nn.DataParallel 包裹过,而检查点是在模型未被包裹时保存的,两边的键名对不上。
因此这一行:
model.load_state_dict(checkpoint['model_state_dict'])
应该在模型被 nn.DataParallel 包裹之前调用;如果必须在包裹之后加载,则改为加载到内部模块:
model.module.load_state_dict(checkpoint['model_state_dict'])
(注意不要直接写 model.load_state_dict(checkpoint) —— checkpoint 里还包含 'epoch'、'loss' 等键,不是一个 state_dict。)
推荐阅读
- bash - 如何编写一个可以切换到 bash 的 shell?
- mysql - mysql查询的重复结果
- windows - 从 Windows 批处理文件重定向
- python - 如何绘制具有两个唯一值的数据框列?
- javascript - 获取 HTML 属性。与 ViewContainerRef?
- c++ - 如何使用正则表达式将输入数字格式化为字符串?
- java - Java 运行时环境检测到一个致命错误 多次关闭并重新打开网络摄像头时
- angular - 输入文件上传在 Angular 材质菜单中不起作用
- google-bigquery - 从 Google 表格创建表格时,如何解决 Google Bigquery 中的错误?
- javascript - 准确的javascript睡眠功能