Do I need to load the weights of another class that I use inside my NN class?

Problem description

I have a model that needs to implement self-attention, and this is how I wrote the code:

class SelfAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(args)

    def forward(self, x):
        # MultiheadAttention returns (attn_output, attn_weights); keep only the output
        attn_output, _ = self.multihead_attn(x, x, x)
        return attn_output

class ActualModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

After loading a checkpoint of ActualModel, should I also load the saved checkpoint of the SelfAttention class inside ActualModel.__init__ before continuing training or running prediction?

If I do create an instance of the SelfAttention class, will the corresponding trained weights of SelfAttention.multihead_attn be loaded by torch.load('actual_model.pth'), or will they be re-initialized?

In other words, is this necessary?

class ActualModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def pred_or_continue_train(self):
        self.self_attention = torch.load('self_attention.pth')

actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_train()
actual_model.eval()

Tags: python, pytorch, artificial-intelligence

Solution


In other words, is this necessary?

In short, no.

If the SelfAttention class is registered as a child nn.Module, and its weights as nn.Parameters or manually registered buffers, it will be loaded automatically along with the rest of the model.
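As a quick illustration of what "registered" means here, a minimal, self-contained sketch (the names Toy, child, scale, and running_mean are made up for this example): all three mechanisms end up as keys in the module's state_dict, which is exactly what a checkpoint saves and restores.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = nn.Linear(4, 4)                           # child nn.Module
        self.scale = nn.Parameter(torch.ones(4))               # nn.Parameter
        self.register_buffer("running_mean", torch.zeros(4))   # manually registered buffer

print(list(Toy().state_dict().keys()))
# ['scale', 'running_mean', 'child.weight', 'child.bias']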

A simple example, based on the model from the question:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, fin, n_h):
        super(SelfAttention, self).__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)
        
    def forward(self, x):
        # MultiheadAttention returns (attn_output, attn_weights); keep only the output
        attn_output, _ = self.multihead_attn(x, x, x)
        return attn_output
    
class ActualModel(nn.Module):
    def __init__(self):
        super(ActualModel, self).__init__()
        self.inp_layer = nn.Linear(10, 20)
        self.self_attention = SelfAttention(20, 1)
        self.out_layer = nn.Linear(20, 1)
    
    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

m = ActualModel()
# Every registered parameter, including those of the nested SelfAttention, shows up here.
for k, v in m.named_parameters():
    print(k)

You will get the following output, where self_attention has been registered successfully:

inp_layer.weight
inp_layer.bias
self_attention.multihead_attn.in_proj_weight
self_attention.multihead_attn.in_proj_bias
self_attention.multihead_attn.out_proj.weight
self_attention.multihead_attn.out_proj.bias
out_layer.weight
out_layer.bias
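
Building on the example above, here is a hedged sketch of the full round trip (the file name actual_model.pth is just a placeholder): saving the state_dict of one ActualModel and loading it into a freshly constructed one restores the nested SelfAttention weights without any separate load step.

# Save the whole model's parameters, including the nested SelfAttention weights.
torch.save(m.state_dict(), 'actual_model.pth')

# A brand-new model starts with freshly initialized weights...
m2 = ActualModel()
# ...but loading the parent's state_dict fills in every registered submodule.
missing, unexpected = m2.load_state_dict(torch.load('actual_model.pth'))
print(missing, unexpected)   # both empty: every key matched

# The nested attention weights now equal the saved ones.
print(torch.equal(
    m.self_attention.multihead_attn.in_proj_weight,
    m2.self_attention.multihead_attn.in_proj_weight,
))  # True

The same reasoning applies if the whole model object was pickled with torch.save(model, 'actual_model.pth'), as in the question: unpickling it recreates every registered submodule together with its trained weights, so a pred_or_continue_train-style reload of self_attention.pth is unnecessary.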
