首页 > 解决方案 > 恢复 NN 前向传递的 const 正确性

问题描述

我正在尝试使用 pytorch/libtorch 实现一个简单的神经网络。以下示例改编自libtorch cpp 前端教程

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    DeepQImpl(size_t N)
        : linear1(2,5),
          linear2(5,3) {}
    torch::Tensor forward(torch::Tensor x) const {
        x = torch::tanh(linear1(x));
        x = linear2(x);
        return x;
    }
    torch::nn::Linear linear1, linear2;
};
TORCH_MODULE(DeepQ);

请注意,该函数forward已声明const。我正在编写的代码要求 NN 的评估是一个 const 函数,这对我来说似乎是合理的。但是,此代码无法编译。编译器抛出

错误:不匹配调用 '(const torch::nn::Linear) (at::Tensor&)'<br> x = linear1(x);

我已经找到了解决这个问题的方法,通过将图层定义为mutable

#include <torch/torch.h>
struct DeepQImpl : torch::nn::Module {
    /* all the code */
    mutable torch::nn:Linear linear1, linear2;
};

所以我的问题是

  1. 为什么在张量上应用层不是const
  2. 正在使用mutable这种方法来解决这个问题,它安全吗?

我的直觉是,在前向传播中,层被组装成一个可用于反向传播的结构,需要一些写入操作。如果这是真的,那么问题就变成了如何在第一步(非const)中组装层,然后在第二步(const)中评估结构。

标签: libtorch

解决方案


推荐阅读