首页 > 解决方案 > Pytorch C++ API:CMake 问题

问题描述

我想将 pytorch C++ API 包含到我正在开发的大型 C++ 软件中。

出于遗留原因,我必须使用find_package 和关联的find_pathandfind_library函数,而不是建议 target_link_libraries的.

这是我的 FindTORCH.cmake :

include( FindPackageHandleStandardArgs )

find_path( TORCH_INCLUDE_DIR torch/torch.h
           PATHS 
           /path/to/libtorch/include/torch/csrc/api/include/
           NO_DEFAULT_PATH )

find_library( TORCH_LIBRARIES libtorch.so
              PATHS
              /path/to/libtorch/lib/
              NO_DEFAULT_PATH )


FIND_PACKAGE_HANDLE_STANDARD_ARGS( TORCH REQUIRED_VARS TORCH_INCLUDE_DIR TORCH_LIBRARIES )

if ( TORCH_FOUND )
message( STATUS "Torch found" )
endif( TORCH_FOUND )

mark_as_advanced( TORCH_LIBRARIES TORCH_INCLUDE_DIR )

在编译时,找到了火炬文件,我可以include <torch/torch.h>在项目中的随机 .cxx 中找到。

但是,如果我添加到 .cxx :

torch::Tensor tensor = torch::rand({2, 3});
cout << tensor << std::endl;

然后我无法再编译,我收到以下错误:

/path/to/libtorch/include/torch/csrc/api/include/torch/utils.h:4:10: fatal error: ATen/record_function.h: No such file or directory
 #include <ATen/record_function.h>
          ^~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.

我正在使用 Ubuntu 18、C++ 14,并且 cmake 版本是 3.10.2 。

提前致谢

标签: c++cmakepytorch

解决方案


Torch 暴露了它自己的目标。要有效地使用它们,只需从您的项目中删除FindTORCH.cmake,并添加/path/to/libtorch/到您的前缀路径:

cmake_minimum_required(VERSION 3.19) # or whatever version you use
project(your-project CXX)

list(APPEND CMAKE_PREFIX_PATH "/path/to/libtorch/")
find_package(Torch REQUIRED CONFIG) # this ensure it find the file provided by pytorch

add_executable(your-executable main.cpp)

target_link_libraries(your-executable PUBLIC torch::Tensor)

如果你真的坚持使用你自己的FindTorch.cmake而不是正确的,你可以修改它来创建一个导入的目标,然后你将链接:

您可以稍微更改您的 find 模块以从中获得现代 CMake 界面:

include( FindPackageHandleStandardArgs )

find_path( TORCH_INCLUDE_DIR torch/torch.h
           PATHS 
           /path/to/libtorch/include/torch/csrc/api/include/
           NO_DEFAULT_PATH )

find_library( TORCH_LIBRARIES libtorch.so
              PATHS
              /path/to/libtorch/lib/
              NO_DEFAULT_PATH )


FIND_PACKAGE_HANDLE_STANDARD_ARGS( TORCH REQUIRED_VARS TORCH_INCLUDE_DIR TORCH_LIBRARIES )

if ( TORCH_FOUND )
    message( STATUS "Torch found" )
    add_library(torch::Tensor SHARED IMPORTED) # mimic the names from pytorch maintainers
    set_target_properties(torch::Tensor 
    PROPERTIES
        IMPORTED_LOCATION "${TORCH_LIBRARIES}"
        INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIR}"
        # on windows, set IMPORTED_IMPLIB to the .lib
    )
endif( TORCH_FOUND )

mark_as_advanced( TORCH_LIBRARIES TORCH_INCLUDE_DIR )

然后,在您的主 CMake 文件中,您可以像使用任何其他目标一样使用导入的目标:

find_package(Torch REQUIRED)

add_executable(your-executable main.cpp)
target_link_libraries(your-executable PUBLIC torch::Tensor)

推荐阅读