pytorch - 在 PyTorch 中根据参数计算简单 NN 的 Hessian 矩阵
问题描述
我对 PyTorch 比较陌生,并试图计算一个非常简单的前馈网络相对于其权重的 Hessian。我正在尝试让torch.autograd.functional.hessian工作。我一直在挖掘论坛,因为这是添加到 PyTorch 中的一个相对较新的功能,所以我无法找到关于它的大量信息。这是我的简单网络架构,它来自 Mnist 上的 Kaggle 上的一些示例代码。
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
self.l1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.l3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = self.l1(x)
x = self.relu(x)
x = self.l3(x)
return F.log_softmax(x, dim = 1)
net = Network()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
loss_func = nn.CrossEntropyLoss()
我正在为一堆时代运行 NN,例如:
for e in range(epochs):
for i in range(0, x.shape[0], batch_size):
x_mini = x[i:i + batch_size]
y_mini = y[i:i + batch_size]
x_var = Variable(x_mini)
y_var = Variable(y_mini)
optimizer.zero_grad()
net_out = net(x_var)
loss = loss_func(net_out, y_var)
loss.backward()
optimizer.step()
if i % 100 == 0:
loss_log.append(loss.data)
然后,我将所有参数添加到一个列表中,并从中创建一个张量,如下所示:
param_list = []
for param in net.parameters():
param_list.append(param.view(-1))
param_list = torch.cat(param_list)
最后,我试图通过运行来计算融合网络的 Hessian:
hessian = torch.autograd.functional.hessian(loss_func, param_list,create_graph=True)
但它给了我这个错误: TypeError: forward() missing 1 required positional argument: 'target'
任何帮助,将不胜感激。
解决方案
关于模型的参数(与模型的输入相反)计算粗麻布目前并没有得到很好的支持。在https://github.com/pytorch/pytorch/issues/49171上正在做一些工作,但目前非常不方便。
您的代码还有一些其他问题 - 在您传递的地方loss_func
,您应该传递一个构造计算图的函数。此外,您永远不会指定网络的输入或损失函数的目标。
这是一些使用现有功能接口计算模型权重的粗麻布的代码,并将所有内容连接在一起以提供与您尝试做的相同的形式:
# Pick a random input to the network
src = torch.rand(1, 2)
# Say our target for our loss is all ones
dst = torch.ones(1, dtype=torch.long)
keys = list(net.state_dict().keys())
parameters = list(net.parameters())
sizes = [x.view(-1).shape[0] for x in parameters]
ndims = sum(sizes)
def hessian_hack(*params):
for i in range(len(keys)):
path = keys[i].split('.')
cur = net
for f in range(0, len(path)-1):
cur = net.__getattr__(path[f])
cur.__delattr__(path[-1])
cur.__setattr__(path[-1], params[i])
return loss_func(net(src), dst)
# sub_hessians[i][f] is the hessian of parameter i vs parameter f
sub_hessians = torch.autograd.functional.hessian(
hessian_hack,
tuple(parameters),
create_graph=True)
# We can combine them all into a nice big hessian.
hessian = torch.cat([
torch.cat([
sub_hessians[i][f].reshape(sizes[i], sizes[f])
for f in range(len(sub_hessians[i]))
], axis=1)
for i in range(len(sub_hessians))
], axis=0)
print(hessian)
推荐阅读
- python - pymc3中的转换率
- javascript - 我可以在一个函数中堆叠不同的 javascript 方法吗?
- python - 3D 曲面图:如何反转轴
- android - Android Studio“底部导航活动”模板在顶部留下空白区域
- javascript - 在 Vue 中观察异步外部 DOM 变化
- powershell - 提取 xml 标记时将文件名附加到输出
- javascript - 将 Json 数据传递给 HTML
- regex - 正则表达式从带小数的文本中提取十进制数
- c# - 如何返回两个值但只在文本框中显示一个并在 C# 中隐藏另一个
- node.js - 为编写 nodejs/nestjs 的微服务创建自己的 API 网关?