首页 > 解决方案 > 在 TensorFlow Lite C API 中注册自定义算子

问题描述

我正在使用 C API 在 Android 上运行 tensorflow lite。我的模型需要RandomStandardNormal最近在 tensorflow 中作为自定义操作原型实现的运算v2.4.0-rc0

TfLiteInterpreterOptionsAddCustomOp()函数在tensorflow/lite/c/c_api_experimental.h中列出:

TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
    TfLiteInterpreterOptions* options, const char* name,
    const TfLiteRegistration* registration, int32_t min_version,
    int32_t max_version);

这个例子&线程,我试图这样使用TfLiteInterpreterOptionsAddCustomOp

// create model and interpreter options
TfLiteModel *model = TfLiteModelCreateFromFile("path/to/model.tflite");
TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();

// register custom ops
TfLiteInterpreterOptionsAddCustomOp(options, "RandomStandardNormal", Register_RANDOM_STANDARD_NORMAL(), 1, 1);

// create the interpreter
TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
TfLiteInterpreterAllocateTensors(interpreter);

我看到该Register_RANDOM_STANDARD_NORMAL()函数是在tensorflow/lite/kernels/custom_ops_register.htflite::ops::custom的C++ 命名空间中定义的。但是,当我尝试将它包含在我的 C 文件中时,编译器会抱怨,因为它是 C 中的未知类型。namespace

如何使用 tensorflow lite C API 注册自定义运算符?我是否需要使用 C++ 编译器才能将 C API 与此自定义运算符一起使用,因为它是在 C++ 中定义的?

注意:我//tensorflow/lite/kernels:custom_ops在编译时包含在 bazel BUILD deps 中libtensorflowlite_c.so

标签: androidctensorflowtensorflow-lite

解决方案


看起来这是通过以下解决方法在 Github 上回答的:

https://github.com/tensorflow/tensorflow/issues/44664#issuecomment-723310060


推荐阅读