pytorch - 火炬。优化模型的输入:尝试第二次向后遍历图形,但缓冲区已被释放
问题描述
目前,我正在尝试将输入张量 x 的值优化为模型。我想将输入限制为仅包含 [0.0;1.0] 范围内的值。
当不使用这样的层时,没有太多关于如何执行此操作的信息。
我在下面创建了一个最小的工作示例,它在这篇文章的标题中给出了错误消息。
魔法发生在 optimize_x() 函数中
如果我注释掉这一行:model.x = model.x.clamp(min=0.0, max=1.0)
问题已解决,但张量显然没有被钳制。
我知道我可以设置retain_graph=True
- 但目前尚不清楚这是否是正确的方法,或者是否有更好的方法来实现此功能?
import torch
from torch.distributions import Uniform
class OptimizeInputModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = torch.nn.Sequential(
torch.nn.Linear(123, 1000),
torch.nn.Dropout(0.4),
torch.nn.ReLU(),
torch.nn.Linear(1000, 100),
torch.nn.Dropout(0.4),
torch.nn.ReLU(),
torch.nn.Linear(100, 1),
torch.nn.Sigmoid(),
)
in_shape = (1, 123)
self.x = torch.ones(in_shape) * 0.1
self.x.requires_grad = True
def forward(self) -> torch.Tensor:
return self.model(self.x)
class MyLossFunc(torch.nn.Module):
def forward(self, y: torch.Tensor) -> torch.Tensor:
loss = torch.sum(-y)
return loss
def optimize_x():
model = OptimizeInputModel()
optimizer = torch.optim.Adam([model.x], lr=1e-4)
loss_fn = MyLossFunc()
for epoch in range(50000):
# Constrain X to have no values < 0
model.x = model.x.clamp(min=0.0, max=1.0)
y = model()
loss = loss_fn(y)
if epoch % 9 == 0:
print(f'Epoch: {epoch}\t Loss: {loss}')
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimize_x()
完整的错误信息:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
解决方案
对于将来可能有同样问题的任何人。
我的解决方案是这样做(注意下划线!):
model.x.data.clamp_(min=0.0, max=1.0)
代替:
model.x = model.x.clamp(min=0.0, max=1.0)
推荐阅读
- html - 为什么这是垂直滚动而不是水平滚动?
- c# - 如何对 ConcurrentBag 进行排序?
- android - Android 辅助功能设置(对讲)焦点
- ios - 调度组错误
- javascript - 如何在使用 ReactJS 过滤元素之前渲染元素?
- ios - 如何滚动到viewdidload时发送的最后一条消息?在迅速 4
- excel - 运行时错误 5 - Excel 日期切片器
- pygame - 在我尝试在游戏中实现碰撞和重力后,我的屏幕变黑了
- forms - Xamarin Forms SearchBar + ListView 更新缓慢
- java - Libgdx ResolutionFileResolver 无法正常工作