python - PyTorch 的 torch.load 与 torch.save - 尝试继续训练时加载 state_dict 出错
问题描述
我收到以下错误:
RuntimeError: Error(s) in loading state_dict for XceptionHourglass: Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "conv2.weight", "conv2.bias", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var".....,
我开始训练:
model = train_mask_net(64)
这会调用函数 train_mask_net,我在其中的 epoch 循环里加入了 torch.save。现在我想在循环开始之前用 torch.load 加载其中一个已保存的模型,然后继续训练。
# Train the mask network for `num_epochs` epochs and return the model.
# Before the epoch loop it tries to resume from a checkpoint file on disk;
# a checkpoint dict is written with torch.save during training.
# NOTE(review): indentation was lost in this paste, so loop nesting is not
# visible here.
def train_mask_net(num_epochs=1):
data = MaskDataset(list(data_mask.keys()))
data_loader = torch.utils.data.DataLoader(data, batch_size=8, shuffle=True, num_workers=4)
model = XceptionHourglass(max_clz+2)
model.cuda()
# Forward passes below go through the DataParallel wrapper `dp`, while
# state_dict() is saved from and loaded into the unwrapped `model`.
dp = torch.nn.DataParallel(model)
loss = nn.CrossEntropyLoss()
params = [p for p in dp.parameters() if p.requires_grad]
optimizer = torch.optim.RMSprop(params, lr=2.5e-4, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=6,
gamma=0.9)
checkpoint = torch.load('imaterialist2020-pretrain-models/maskmodel_160.model_ep4_tsave')
#print(checkpoint)
# NOTE(review): the torch.save call further down writes a dict with the keys
# 'model', 'optimizer', 'loss' and 'epoch'. If this file has that format,
# passing the whole dict to load_state_dict is consistent with the reported
# "Missing key(s) in state_dict" error; these two calls need the nested
# entries checkpoint['model'] and checkpoint['optimizer'] instead.
model.load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
# NOTE(review): this `epoch` is rebound by the `for epoch in range(...)` loop
# below, so the resumed epoch number is discarded and counting restarts at 0.
epoch = checkpoint['epoch']
# Rebinds `loss` (the criterion created above) to whatever was stored under
# 'loss'; the save call stores the `loss` object itself, not the per-batch
# `losses` value.
loss = checkpoint['loss']
#print('epoch', epoch)
for epoch in range(num_epochs):
print('epoch', epoch)
#print('loss in epoch', loss)
total_loss = []
prog = tqdm(data_loader, total=len(data_loader))
for i, (imag, mask) in enumerate(prog):
X = imag.cuda()
y = mask.cuda()
xx = dp(X)
# Flatten targets to 1-D and predictions to (N, channel) before applying
# the criterion.
# to 1D-array
y = y.reshape((y.size(0),-1)) # batch, flatten-img
y = y.reshape((y.size(0) * y.size(1),)) # flatten-all
xx = xx.reshape((xx.size(0), xx.size(1), -1)) # batch, channel, flatten-img
xx = torch.transpose(xx, 2, 1) # batch, flatten-img, channel
xx = xx.reshape((xx.size(0) * xx.size(1),-1)) # flatten-all, channel
losses = loss(xx, y)
prog.set_description("loss:%05f"%losses)
optimizer.zero_grad()
losses.backward()
optimizer.step()
total_loss.append(losses.detach().cpu().numpy())
# Write a checkpoint dict; the file name embeds attr_image_size[0] and the
# current epoch. Anything loading this file must index into the dict by key.
torch.save({'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'loss': loss,
'epoch': epoch,
}, MODEL_FILE_DIR+"maskmodel_%d.model"%attr_image_size[0]+'_ep'+str(epoch)+'_tsave')
# Earlier variant that saved only the bare model state_dict (no wrapping dict).
#torch.save(model.state_dict(), MODEL_FILE_DIR+"maskmodel_%d.model"%attr_image_size[0]+'_ep'+str(epoch)+'_tsave')
# Drop references to the progress bar and tensors before emptying the CUDA
# cache and running the garbage collector.
prog, X, xx, y, losses = None, None, None, None, None,
torch.cuda.empty_cache()
gc.collect()
return model
我认为这部分并不相关,不过 XceptionHourglass 类大致如下:
class XceptionHourglass(nn.Module):
def __init__(self, num_classes):
super(XceptionHourglass, self).__init__()
self.num_classes = num_classes
self.conv1 = nn.Conv2d(3, 128, 3, 2, 1, bias=True)
self.bn1 = nn.BatchNorm2d(128)
self.mish = Mish()
self.conv2 = nn.Conv2d(128, 256, 3, 1, 1, bias=True)
self.bn2 = nn.BatchNorm2d(256)
self.block1 = HourglassNet(4, 256)
self.bn3 = nn.BatchNorm2d(256)
self.block2 = HourglassNet(4, 256)
...
解决方案
您在以下几行 .load_state_dict() 调用中犯了一个错误:
model.load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
应改为:
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
推荐阅读
- reactjs - Firebase 错误:未创建 Firebase 应用默认值
- python - If 语句中未定义变量(Python,Ursina 模块)
- html - 清理 wordpress woocommerce 产品描述的 HTML?
- python - 在没有 NVIDIA GPU 的情况下使用 CUDA?
- functional-programming - Elm 指南中的骰子练习
- dialogflow-es - Dialogflow 知识库 CSV 上传错误
- sed - 使用 Sed 将基本文件名解析为替换参数
- javascript - 哪种设计模式可以很好地接收来自不同前端的出价到 javascript 中的单个后端
- arrays - 在不创建新数组的情况下替换数组中元素的值
- ruby-on-rails - 在 Rails 中通过 Ruby GraphQL 使用变量