c++ - 使用 LibTorch (PyTorch) 时在 C++ 中将 at::Tensor 转换为双精度
问题描述
在下面的代码中,我想将loss
(data type at::Tensor
) 与lossThreshold
(data type double
) 进行比较。我想在进行比较之前转换loss
为。double
我该怎么做?
int main() {
auto const input1(torch::randn({28*28});
auto const input2(torch::randn({28*28});
double const lossThreshold{0.05};
auto const loss{torch::nn::functional::mse_loss(input1, input2)}; // this returns an at::Tensor datatype
return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}
解决方案
感谢 GitHub CoPilot 推荐了这个解决方案。我想我现在应该辞职了。:(
解决方案是使用item<T>()
模板函数如下:
int main() {
auto const input1(torch::randn({28*28}); // at::Tensor
auto const input2(torch::randn({28*28}); // at::Tensor
double const lossThreshold{0.05}; // double
auto const loss{torch::nn::functional::mse_loss(input1, input2).item<double>()}; // the item<double>() converts at::Tensor to double
return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}
推荐阅读
- javascript - Show remaining months in year
- javascript - 将image/jpeg、image/png等文件转换为multipart/form-data格式节点js
- ionic4 - Ionic 4. 打开 GPS 后 getCurrentPosition 不起作用
- docusignapi - Docusign Postman 集合缺少变量
- javascript - 我可以对对象数组进行二进制搜索吗?
- excel - 如何根据特定文本剪切整行
- java - 点击通知后活动刷新
- sequence - 在 ADF 中使用复制活动添加顺序自定义列
- pine-script - 如何将代码从版本 #2 转换为版本 #4 的 pine 脚本?
- angular - Angular 在严格模式下使用 ngrx