首页 > 解决方案 > 在 PyTorch 中根据参数计算简单 NN 的 Hessian 矩阵

问题描述

我对 PyTorch 比较陌生,并试图计算一个非常简单的前馈网络相对于其权重的 Hessian。我正在尝试让torch.autograd.functional.hessian工作。我一直在挖掘论坛,因为这是添加到 PyTorch 中的一个相对较新的功能,所以我无法找到关于它的大量信息。这是我的简单网络架构,它来自 Mnist 上的 Kaggle 上的一些示例代码。

class Network(nn.Module):
    
    def __init__(self):
        super(Network, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l3 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l3(x)
        return F.log_softmax(x, dim = 1)
net = Network()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
loss_func = nn.CrossEntropyLoss()

我正在为一堆时代运行 NN,例如:

for e in range(epochs):
    for i in range(0, x.shape[0], batch_size):
        x_mini = x[i:i + batch_size] 
        y_mini = y[i:i + batch_size] 
        x_var = Variable(x_mini)
        y_var = Variable(y_mini)
        optimizer.zero_grad()
        net_out = net(x_var)
        loss = loss_func(net_out, y_var)
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            loss_log.append(loss.data)

然后,我将所有参数添加到一个列表中,并从中创建一个张量,如下所示:

param_list = []
for param in net.parameters():
    param_list.append(param.view(-1))
param_list = torch.cat(param_list)

最后,我试图通过运行来计算融合网络的 Hessian:

hessian = torch.autograd.functional.hessian(loss_func, param_list,create_graph=True)

但它给了我这个错误: TypeError: forward() missing 1 required positional argument: 'target'

任何帮助,将不胜感激。

标签: pytorchhessian

解决方案


关于模型的参数(与模型的输入相反)计算粗麻布目前并没有得到很好的支持。在https://github.com/pytorch/pytorch/issues/49171上正在做一些工作,但目前非常不方便。

您的代码还有一些其他问题 - 在您传递的地方loss_func,您应该传递一个构造计算图的函数。此外,您永远不会指定网络的输入或损失函数的目标。

这是一些使用现有功能接口计算模型权重的粗麻布的代码,并将所有内容连接在一起以提供与您尝试做的相同的形式:

# Pick a random input to the network                             
src = torch.rand(1, 2)                                           
# Say our target for our loss is all ones                        
dst = torch.ones(1, dtype=torch.long)                            
                                                                 
keys = list(net.state_dict().keys())                             
parameters = list(net.parameters())                              
sizes = [x.view(-1).shape[0] for x in parameters]                
ndims = sum(sizes)                                               
                                                                 
def hessian_hack(*params):                                       
    for i in range(len(keys)):                                   
        path = keys[i].split('.')                                
        cur = net                                                
        for f in range(0, len(path)-1):                          
            cur = net.__getattr__(path[f])                       
        cur.__delattr__(path[-1])                                
        cur.__setattr__(path[-1], params[i])                     
    return loss_func(net(src), dst)                              
                                                                 
# sub_hessians[i][f] is the hessian of parameter i vs parameter f
sub_hessians = torch.autograd.functional.hessian(                
    hessian_hack,                                                
    tuple(parameters),                                           
    create_graph=True)                                           
                                                                 
# We can combine them all into a nice big hessian.               
hessian = torch.cat([                                            
        torch.cat([                                              
            sub_hessians[i][f].reshape(sizes[i], sizes[f])       
            for f in range(len(sub_hessians[i]))                 
        ], axis=1)                                               
    for i in range(len(sub_hessians))                            
], axis=0)                                                       
print(hessian)                                                   

推荐阅读