首页 > 解决方案 > 每次迭代后保存 PyTorch VGG 模型的权重

问题描述

每次迭代后都会保存相同的 VGG 分类器权重。但是字典classifier_weights在每次迭代中都包含相同的权重。

下面是代码片段:

classifier_weights = {}
task_model.train()

for iter_count in range(self.args.train_iterations):
    if iter_count != 0 and iter_count % lr_change == 0:
        for param in optim_task_model.param_groups:
            param['lr'] = param['lr'] / 10

     # task_model step
        preds = task_model(labeled_imgs)
        task_loss = self.ce_loss(preds, labels)
        optim_task_model.zero_grad()
        task_loss.backward()
        optim_task_model.step()

        if iter_count > self.args.weight_save_iter:
            classifier_weights[iter_count - self.args.weight_save_iter - 1] = {}
            classifier_weights[iter_count - self.args.weight_save_iter - 1][0] = task_model.classifier[0].weight.data
            classifier_weights[iter_count - self.args.weight_save_iter - 1][1] = task_model.classifier[3].weight.data
            classifier_weights[iter_count - self.args.weight_save_iter - 1][2] = task_model.classifier[6].weight.data

标签: modelpytorch

解决方案


推荐阅读