首页 > 解决方案 > 如何使用 pybind11 将 python tf.Tensor 转换为 C++ Tensor

问题描述

我正在使用 pybind11 在 C++ 中开发 Python 库。该库旨在与 TensorFlow 等深度学习包一起使用,因此它公开的一些函数将 TensorFlow 对象作为参数,在特定tf.Tensor情况下也是如此。在 C++ 级别,这些函数使用以下原型定义:

void f(const py11::object& theTensor) {
   ...
}

我正在寻找一种方法来让正在传递的 C++tensorflow::Tensor包裹起来。py11::object

我对 TensorFlow 的理解是它使用 SWIG 为 C++ 接口创建 Python 包装器。但是,我找不到tf.Tensor包装器的定义位置,而且我不熟悉 SWIG 将 C++ 类公开为 Python 类的方式。

我应该能够从using检索PyObject*指针,但是我不知道如何将其转换为指向 C++ Tensor 对象的指针。py11::objecttheTensor.ptr()

标签: pythonc++tensorflowswigpybind11

解决方案


推荐阅读