c++ - TF2.2: Building libtensorflow_cc.so for the C++ API
Problem description
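System information:
- OS platform and distribution: Linux Ubuntu 18.04
- TensorFlow installed from: source
- TensorFlow version: 2.2.0 (stable)
- Python version: python3
- Bazel version: Bazelisk with Bazel 2.0.0, as required by tf-2.2.0
- GCC/compiler version (if compiling from source): GCC-8
- CUDA/cuDNN version: no CUDA (for now)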
What I did:
Cloned TensorFlow from GitHub:
git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git
cd tensorflow
git checkout v2.2.0
Installed and set up Bazelisk (following this guide: https://gist.github.com/philwo/f3a8144e46168f23e40f291ffe92e63c):
$ sudo curl -Lo /usr/local/bin/bazel https://github.com/bazelbuild/bazelisk/releases/download/v1.1.0/bazelisk-linux-amd64
$ sudo chmod +x /usr/local/bin/bazel
$ grep -r _TF_MAX_BAZEL_VERSION .
./configure.py:_TF_MAX_BAZEL_VERSION = '2.0.0'
$ echo '2.0.0' > .bazelversion
$ bazel version
Built TensorFlow with Bazel:
$ ./configure
$ bazel --host_jvm_args=-Xmx30G build --jobs=8 --config=monolithic --config=v2 --config=opt --copt=-O3 --copt=-march=native --copt=-m64 --verbose_failures //tensorflow:tensorflow //tensorflow:tensorflow_cc //tensorflow:tensorflow_framework //tensorflow/tools/lib_package:libtensorflow
Then I wrote this code:
#include <stdlib.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "class_name.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"  // str_util::EndsWith
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
using namespace tensorflow;
using tensorflow::Flag;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::Tensor;
// Read the image file, decode it according to its extension (PNG, GIF,
// otherwise JPEG), cast it to float and add a batch dimension.
// Note: i_height and i_width are not used yet; the image is not resized.
int TensorFromFile(const string& filename, const int i_height, const int i_width,
                   std::vector<Tensor>* o_tensors) {
  tensorflow::Status status;
  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;
  std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession({}));
  tensorflow::GraphDef graph;
  auto reader = tensorflow::ops::ReadFile(root.WithOpName("img_reader"), filename);
  const int channels = 1;
  tensorflow::Output imgreader;
  if (tensorflow::str_util::EndsWith(filename, ".png")) {
    imgreader = DecodePng(root.WithOpName("png_reader"), reader, DecodePng::Channels(channels));
  } else if (tensorflow::str_util::EndsWith(filename, ".gif")) {
    // DecodeGif returns a 4-D tensor [frames, height, width, 3]; squeeze away
    // the frame dimension (assumes a single-frame GIF) so the result is 3-D.
    imgreader = Squeeze(root.WithOpName("squeeze_first_dim"),
                        DecodeGif(root.WithOpName("gif_reader"), reader));
  } else {
    imgreader = DecodeJpeg(root.WithOpName("jpeg_reader"), reader, DecodeJpeg::Channels(channels));
  }
  // Float conversion plus a leading batch dimension: [1, height, width, channels].
  auto f_caster = Cast(root.WithOpName("float_caster"), imgreader, tensorflow::DT_FLOAT);
  ExpandDims(root.WithOpName("output"), f_caster, 0);
  // Serialize the ops built in root into a GraphDef and run it in its own session.
  status = root.ToGraphDef(&graph);
  if (!status.ok()) {
    LOG(ERROR) << status.ToString();
    return -1;
  }
  status = session->Create(graph);
  if (!status.ok()) {
    LOG(ERROR) << status.ToString();
    return -1;
  }
  status = session->Run({}, {"output"}, {}, o_tensors);
  if (!status.ok()) {
    LOG(ERROR) << status.ToString();
    return -1;
  }
  return 0;
}
int main(int argc, char* argv[]) {
  using namespace ::tensorflow::ops;
  tensorflow::Status status;
  std::string delimiter = ".";
  std::string ofilename;
  std::vector<Tensor> inputs;
  std::vector<Tensor> outputs;
  std::string graph_path = "../../graphs/test0/";
  std::string image_path = "../../graphs/test0.png";
  std::string mdlpath(graph_path);
  std::string imgpath(image_path);
  int32 inputdim = 32;
  std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession({}));
  tensorflow::GraphDef graph;
  // Read the model. ReadBinaryProto() reads a single serialized GraphDef file,
  // so mdlpath must name that file rather than a directory.
  status = ReadBinaryProto(Env::Default(), mdlpath, &graph);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return -1;
  }
  // Load the graph into the session
  status = session->Create(graph);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return -1;
  }
  // Read the input image (assumed to be square)
  if (TensorFromFile(imgpath, inputdim, inputdim, &inputs)) {
    LOG(ERROR) << "Image reading failed"
               << "\n";
    return -1;
  }
  LOG(INFO) << "OK";
  std::cout << "input dimension of the image: " << inputs[0].DebugString() << std::endl;
  // Take the first and last nodes of the graph as the input and output layers
  auto inputlayer = graph.node(0).name();
  auto outputlayer = graph.node(graph.node_size() - 1).name();
  status = session->Run({{inputlayer, inputs[0]}}, {outputlayer}, {}, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << status.ToString();
    return -1;
  }
  std::cout << "output dimension of the image: " << outputs[0].DebugString() << std::endl;
  // Build the output filename. rfind() locates the extension dot; find() would
  // match the first '.' of "../" and drop the whole path.
  ofilename.append(imgpath.substr(0, imgpath.rfind(delimiter)));
  ofilename.append("_mask.png");
  std::cout << "output filename: " << ofilename << std::endl;
  // Now write this to an image file
  //if (TensorToFile(ofilename, outputs, threshold)) return -1;
  session->Close();
  return 0;
}
I tried to compile it.
If I use these flags:
g++ -O3 -m64 -o test -I /opt/tpt/tensorflow_cpp_scratch/include/tensorflow/bazel-bin main.cpp -L /opt/tpt/tensorflow_cpp_scratch/lib -l tensorflow_cc
I get this error:
/tmp/cc4xzZGr.o: In function `google::protobuf::RepeatedPtrField<tensorflow::NodeDef>::TypeHandler::WeakType const& google::protobuf::internal::RepeatedPtrFieldBase::Get<google::protobuf::RepeatedPtrField<tensorflow::NodeDef>::TypeHandler>(int) const':
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0x6a): undefined reference to `google::protobuf::internal::LogMessage::LogMessage(google::protobuf::LogLevel, char const*, int)'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0x79): undefined reference to `google::protobuf::internal::LogMessage::operator<<(char const*)'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0x86): undefined reference to `google::protobuf::internal::LogFinisher::operator=(google::protobuf::internal::LogMessage&)'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0x8e): undefined reference to `google::protobuf::internal::LogMessage::~LogMessage()'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0xb2): undefined reference to `google::protobuf::internal::LogMessage::LogMessage(google::protobuf::LogLevel, char const*, int)'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0xc1): undefined reference to `google::protobuf::internal::LogMessage::operator<<(char const*)'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0xce): undefined reference to `google::protobuf::internal::LogFinisher::operator=(google::protobuf::internal::LogMessage&)'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0xd6): undefined reference to `google::protobuf::internal::LogMessage::~LogMessage()'
main.cpp:(.text._ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi[_ZNK6google8protobuf8internal20RepeatedPtrFieldBase3GetINS0_16RepeatedPtrFieldIN10tensorflow7NodeDefEE11TypeHandlerEEERKNT_8WeakTypeEi]+0xf5): undefined reference to `google::protobuf::internal::LogMessage::~LogMessage()'
collect2: error: ld returned 1 exit status
It seems the protobuf library is missing, so I also added -ltensorflow_framework.
That compiles without errors, but at runtime I get this error:
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/descriptor_database.cc:118] File already exists in database: google/protobuf/any.proto
[libprotobuf FATAL external/com_google_protobuf/src/google/protobuf/descriptor.cc:1367] CHECK failed: GeneratedDatabase()->Add(encoded_file_descriptor, size):
terminate called after throwing an instance of 'google::protobuf::FatalException'
what(): CHECK failed: GeneratedDatabase()->Add(encoded_file_descriptor, size):
Aborted (core dumped)
Why am I getting these errors?
I have read https://github.com/tensorflow/tensorflow/issues/14632 and https://github.com/tensorflow/tensorflow/issues/40004, but had no luck.
Thanks