python - Pytorch:输出wrt参数的梯度
问题描述
我有兴趣找到关于参数(权重和偏差)的神经网络输出的梯度。
更具体地说,假设我有以下神经网络结构 [6,4,3,1]。输入样本大小是 20。我感兴趣的是找到权重(和偏差)的神经网络输出的梯度,如果我没记错的话,在这种情况下是 47。在文献中,这个梯度有时称为 Weight_Jacobian。
我在 Jupyter Notebook 上的 Python 3.6 上使用 Pytorch 0.4.0 版。
我制作的代码是这样的:
def init_params(layer_sizes, scale=0.1, rs=npr.RandomState(0)):
return [(rs.randn(insize, outsize) * scale, # weight matrix
rs.randn(outsize) * scale) # bias vector
for insize, outsize in
zip(layer_sizes[:-1],layer_sizes[1:])]
layers = [6, 4, 3, 1]
w = init_params(layers)
first_layer_w = Variable(torch.tensor(w[0][0],requires_grad=True))
first_layer_bias = Variable(torch.tensor(w[0][1],requires_grad=True))
second_layer_w = Variable(torch.tensor(w[1][0],requires_grad=True))
second_layer_bias = Variable(torch.tensor(w[1][1],requires_grad=True))
third_layer_w = Variable(torch.tensor(w[2][0],requires_grad=True))
third_layer_bias = Variable(torch.tensor(w[2][1],requires_grad=True))
X = Variable(torch.tensor(X_batch),requires_grad=True)
output=torch.tanh(torch.mm(torch.tanh(torch.mm(torch.tanh(torch.mm(X,first_layer_w)+first_layer_bias),second_layer_w)+second_layer_bias),third_layer_w)+third_layer_bias)
output.backward()
从代码中可以明显看出,我使用双曲正切作为非线性。该代码生成长度为 20 的输出向量。现在,我有兴趣在所有权重(全部 47 个)中找到此输出向量的梯度。我在这里阅读了 Pytorch 的文档。例如,我也在 这里看到了类似的问题。但是,我没能找到输出向量 wrt 参数的梯度。如果我使用 Pytorch 函数backward(),它会生成一个错误
RuntimeError: grad can be implicitly created only for scalar outputs
我的问题是,有没有办法计算输出向量 wrt 参数的梯度,它基本上可以表示为 20*47 矩阵,因为我的输出向量的大小为 20,参数向量的大小为 47?如果是这样,如何?我的代码有什么问题吗?你可以举任何 X 的例子,只要它的尺寸是 20*6。
解决方案
您正在尝试计算函数的雅可比行列式,而 PyTorch 期望您计算向量-雅可比行积。您可以在此处查看使用 PyTorch 计算雅可比矩阵的深入讨论。
你有两个选择。您的第一个选择是使用JAX或autograd 并使用 jacobian() 函数。您的第二个选择是坚持使用 Pytorch 并通过调用backwards(vec)
20 次来计算 20 个向量雅可比乘积,其中vec
是一个长度为 20 的 one-hot 向量,其中 1 的组件的索引范围从 0 到 19。如果这令人困惑,我建议阅读 JAX 教程中的autodiff 食谱。
推荐阅读
- php - 从数据库中获取具有一对多关系的特定产品
- html - 使元素旋转并溢出到顶部+圆形内边框
- apache - 自托管网络服务器并保持“匿名”的最佳方式?
- python - python中是否有任何模块用于获取GMToffset
- c++ - 比较浮点数时是否有更好的方法来检查相反的符号?
- heroku - Heroku:来自容器和 Heroku PostgreSQL 的连接速度很慢
- pine-script - 为什么我的 PineScript 代码不起作用?
- javascript - 打开包含 700 张照片的页面时页面崩溃和滞后
- mysql - Sequelize:删除关联的行
- c++ - 如何在这个 GTK4 表单中添加两个按钮?