首页 > 解决方案 > 火炬。优化模型的输入:尝试第二次向后遍历图形,但缓冲区已被释放

问题描述

目前,我正在尝试将输入张量 x 的值优化为模型。我想将输入限制为仅包含 [0.0;1.0] 范围内的值。

当不使用这样的层时,没有太多关于如何执行此操作的信息。

我在下面创建了一个最小的工作示例,它在这篇文章的标题中给出了错误消息。

魔法发生在 optimize_x() 函数中

如果我注释掉这一行:model.x = model.x.clamp(min=0.0, max=1.0)问题已解决,但张量显然没有被钳制。

我知道我可以设置retain_graph=True- 但目前尚不清楚这是否是正确的方法,或者是否有更好的方法来实现此功能?

import torch
from torch.distributions import Uniform


class OptimizeInputModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
                torch.nn.Linear(123, 1000),
                torch.nn.Dropout(0.4),
                torch.nn.ReLU(),
                torch.nn.Linear(1000, 100),
                torch.nn.Dropout(0.4),
                torch.nn.ReLU(),
                torch.nn.Linear(100, 1),
                torch.nn.Sigmoid(),
        )

        in_shape = (1, 123)
        self.x = torch.ones(in_shape) * 0.1
        self.x.requires_grad = True

    def forward(self) -> torch.Tensor:
        return self.model(self.x)

class MyLossFunc(torch.nn.Module):

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        loss = torch.sum(-y)
        return loss


def optimize_x():
    model = OptimizeInputModel()
    optimizer = torch.optim.Adam([model.x], lr=1e-4)
    loss_fn = MyLossFunc()
    for epoch in range(50000):
        # Constrain X to have no values < 0
        model.x = model.x.clamp(min=0.0, max=1.0)
        y = model()
        loss = loss_fn(y)

        if epoch % 9 == 0:
            print(f'Epoch: {epoch}\t Loss: {loss}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


optimize_x()

完整的错误信息: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

标签: pytorch

解决方案


对于将来可能有同样问题的任何人。

我的解决方案是这样做(注意下划线!):

model.x.data.clamp_(min=0.0, max=1.0)

代替:

model.x = model.x.clamp(min=0.0, max=1.0)

推荐阅读