python - 如何摆脱 PyTorch.autograd 中的变量 API?
问题描述
我正在X
通过两个简单的 nn.Module PyTorch 模型实例model1
和model2
.
如果不使用折旧的Variable
API,我无法让这个过程正常工作。
所以这很好用:
y1 = model1(X)
v = Variable(y1.data, requires_grad=training) # Its all about this line!
y2 = model2(v)
criterion = nn.NLLLoss()
loss = criterion(y2, y)
loss.backward()
y1.backward(v.grad)
self.step()
但这会引发错误:
y1 = model1(X)
y2 = model2(y1)
criterion = nn.NLLLoss()
loss = criterion(y2, y)
loss.backward()
y1.backward(y1.grad) # it breaks here
self.step()
>>> RuntimeError: grad can be implicitly created only for scalar outputs
我似乎无法v
在第一个实现和y1
第二个实现之间找到相关的区别。在这两种情况下requires_grad
都设置为True
。我唯一能找到的y1.grad_fn=<ThnnConv2DBackward>
是v.grad_fn=<ThnnConv2DBackward>
我在这里想念什么?我不知道什么(张量属性?),如果Variable
折旧了,还有什么其他实现可以工作?
解决方案
经过一番调查,我得出了以下两个解决方案。该线程中其他地方提供的解决方案手动保留了计算图,没有释放它们的选项,因此最初运行良好,但后来导致 OOM 错误。
第一个解决方案是使用内置函数将模型绑定在一起torch.nn.Sequential
:
model = torch.nn.Sequential(Model1(), Model2())
就这么简单。它看起来很干净,行为与普通模型完全一样。
另一种方法是简单地将它们手动绑定在一起:
model1 = Model1()
model2 = Model2()
y1 = model1(X)
y2 = model2(y1)
loss = criterion(y2, y)
loss.backward()
我担心这只会反向传播model2
被证明是没有根据的,因为model1
它也存储在反向传播的计算图中。与以前的实现相比,这种实现提高了两个模型之间的接口的透明度。
推荐阅读
- mysql - 将 5.1 升级到 5.7 后,mysql 在 CentOS 6.9 中无法启动
- visual-studio-code - 如何与 VS Code 稳定版和内部版本共享扩展和设置?
- laravel - Laravel Foreach 循环
- tensorflow - `tf.data.Dataset.map` 是否保留输入顺序?
- ruby-on-rails - 如何在 CSV 导入中自动设置某些行值到 Ruby on Rails 应用程序中
- python - 在 GCP 上运行时出错:意外的关键字参数“maximum_iterations”
- javascript - 如何在 node.js 中过滤 Twit (API) JSON 响应?
- c# - 在 C# 中删除单个数据集中相对于另一个数据集中的重复项
- android - 无法解析符号“@style/Widget.Design.CoordinatorLayout”
- c# - 如何用另一个代码调用 Input.GetButtonDown?