python - 我是否需要加载我在 NN 类中使用的另一个类的权重?
问题描述
我有一个需要实现自我注意的模型,这就是我编写代码的方式:
class SelfAttention(nn.Module):
def __init__(self, args):
self.multihead_attn = torch.nn.MultiheadAttention(args)
def foward(self, x):
return self.multihead_attn.forward(x, x, x)
class ActualModel(nn.Module):
def __init__(self):
self.inp_layer = nn.Linear(arg1, arg2)
self.self_attention = SelfAttention(some_args)
self.out_layer = nn.Linear(arg2, 1)
def forward(self, x):
x = self.inp_layer(x)
x = self.self_attention(x)
x = self.out_layer(x)
return x
加载检查点后ActualModel
,ActualModel.__init__
在继续训练期间或预测期间,我应该加载保存的模型检查点类SelfAttention
吗?
如果我创建一个 class 的实例,如果我这样做SelfAttention
了,对应的训练权重会SelfAttention.multihead_attn
被加载torch.load(actual_model.pth)
还是会被重新初始化?
换句话说,这有必要吗?
class ActualModel(nn.Module):
def __init__(self):
self.inp_layer = nn.Linear(arg1, arg2)
self.self_attention = SelfAttention(some_args)
self.out_layer = nn.Linear(arg2, 1)
def pred_or_continue_train(self):
self.self_attention = torch.load('self_attention.pth')
actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_training()
actual_model.eval()
解决方案
换句话说,这有必要吗?
简而言之,不。
如果SelfAttention
该类已注册为 nn.module、nn.Parameters 或手动注册的缓冲区,则该类将被自动加载。
一个简单的例子:
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, fin, n_h):
super(SelfAttention, self).__init__()
self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)
def foward(self, x):
return self.multihead_attn.forward(x, x, x)
class ActualModel(nn.Module):
def __init__(self):
super(ActualModel, self).__init__()
self.inp_layer = nn.Linear(10, 20)
self.self_attention = SelfAttention(20, 1)
self.out_layer = nn.Linear(20, 1)
def forward(self, x):
x = self.inp_layer(x)
x = self.self_attention(x)
x = self.out_layer(x)
return x
m = ActualModel()
for k, v in m.named_parameters():
print(k)
你会得到如下,在哪里self_attention
注册成功。
inp_layer.weight
inp_layer.bias
self_attention.multihead_attn.in_proj_weight
self_attention.multihead_attn.in_proj_bias
self_attention.multihead_attn.out_proj.weight
self_attention.multihead_attn.out_proj.bias
out_layer.weight
out_layer.bias
推荐阅读
- c++ - std::make_shared 和受保护/私有构造函数
- java - 使用此代码后我的 Eclipse 崩溃(cpu 100%),任何人都可以确认代码是否有效
- android - 编辑文本上的双向绑定Android
- go - fmt.Println() 函数中的意外输出
- algorithm - 为什么我们不能总是选择带有大 O 符号的最大项?
- python - 在 Tkinter 画布中删除元素的问题
- c++ - 解决hackerearth平台给出的问题
- javascript - 用于单一导入的 Jest 模拟模块
- python - BigQuery 加载作业在 JSON 中的布尔数据类型字段上失败
- java - 引起:java.lang.ClassNotFoundException:找不到类“com.google.android.gms.common.internal.zzbq”