首页 > 解决方案 > 使用 LibTorch (PyTorch) 时在 C++ 中将 at::Tensor 转换为双精度

问题描述

在下面的代码中,我想将loss(data type at::Tensor) 与lossThreshold(data type double) 进行比较。我想在进行比较之前转换loss为。double我该怎么做?

int main() {
    auto const input1(torch::randn({28*28});
    auto const input2(torch::randn({28*28});
    double const lossThreshold{0.05};
    auto const loss{torch::nn::functional::mse_loss(input1, input2)}; // this returns an at::Tensor datatype
    return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}

标签: c++pytorchlibtorch

解决方案


感谢 GitHub CoPilot 推荐了这个解决方案。我想我现在应该辞职了。:(

解决方案是使用item<T>()模板函数如下:

int main() {
    auto const input1(torch::randn({28*28}); // at::Tensor
    auto const input2(torch::randn({28*28}); // at::Tensor
    double const lossThreshold{0.05}; // double
    auto const loss{torch::nn::functional::mse_loss(input1, input2).item<double>()}; // the item<double>() converts at::Tensor to double
    return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}

推荐阅读