c++ - C++ 中仅 gRPC 的 Tensorflow Serving 客户端
问题描述
关于如何在 Python(甚至其他一些语言)中创建仅使用 gRPC 的客户端,似乎已有不少资料,
并且我已经成功地用 Python 写出了一个仅使用 gRPC、适用于我们实现的可工作客户端。
但我似乎找不到有人用 C++ 成功编写这种客户端的案例。
任务的约束如下:
- 构建系统不能是 bazel,因为最终的应用程序已经有自己的构建系统。
- 客户端不能包含 Tensorflow(在 C++ 中针对它进行构建需要 bazel)。
- 应用程序应该使用 gRPC 而不是 HTTP 调用,以提高速度。
- 理想情况下,应用程序不会调用 Python,也不会以其他方式执行 shell 命令。
鉴于上述限制,并假设我已经提取并生成了 gRPC 存根,这是否可能?如果可能,能否提供一个例子?
解决方案
事实证明,如果您已经在 Python 中完成了它,这并不是什么新鲜事。假设模型被命名为“predict”并且模型的输入被称为“inputs”,下面是 Python 代码:
import logging
import grpc
from grpc import RpcError
from types_pb2 import DT_FLOAT
from tensor_pb2 import TensorProto
from tensor_shape_pb2 import TensorShapeProto
from predict_pb2 import PredictRequest
from prediction_service_pb2_grpc import PredictionServiceStub
class ModelClient:
    """Small facade over the TensorFlow Serving gRPC prediction API.

    An instance remembers the model name, input dimensions, dtype and model
    version; the gRPC channel and stub are created by ``connect``, which
    ``predict`` calls automatically when no connection has been set up yet.
    """

    # Connection state shared as class-level defaults; connect() overrides
    # these on the instance.
    logger = logging.getLogger(__name__)
    stub = None
    chan = None
    port = None
    host = None

    def __init__(self, name, dims, dtype=DT_FLOAT, version=1):
        """Store the model name, input dimensions, dtype and version.

        ``dims`` is an iterable of sizes; each becomes one
        ``TensorShapeProto.Dim`` entry.
        """
        self.version = version
        self.dtype = dtype
        self.dims = [TensorShapeProto.Dim(size=extent) for extent in dims]
        self.model = name

    @property
    def hostport(self):
        """Return the target address formatted as ``host:port``."""
        return "{}:{}".format(self.host, self.port)

    def connect(self, host='localhost', port=8500):
        """Open an insecure gRPC channel to ``host:port`` and build the stub."""
        self.host = host
        self.port = int(port)

        target = self.hostport
        self.logger.info(f"Connecting to {target}...")
        self.chan = grpc.insecure_channel(target)

        self.logger.info("Initializing prediction gRPC stub.")
        self.stub = PredictionServiceStub(self.chan)

    def tensor_proto_from_measurement(self, measurement):
        """Wrap ``measurement`` in a ``TensorProto`` using the stored dtype/shape."""
        self.logger.info("Assembling measurement tensor.")
        shape = TensorShapeProto(dim=self.dims)
        # NOTE(review): the raw bytes are placed in string_val even though the
        # dtype defaults to DT_FLOAT -- confirm the server decodes this as
        # intended for the deployed model.
        payload = [bytes(measurement)]
        return TensorProto(
            dtype=self.dtype,
            tensor_shape=shape,
            string_val=payload,
        )

    def predict(self, measurement, timeout=10):
        """Send ``measurement`` to the model and return the Predict response.

        Connects with the default host/port first if any piece of connection
        state is still missing. Returns ``None`` when the RPC raises
        ``RpcError`` (the error is logged).
        """
        connection_state = (self.host, self.port, self.chan, self.stub)
        if any(item is None for item in connection_state):
            self.connect()

        self.logger.info("Creating request.")
        request = PredictRequest()
        request.model_spec.name = self.model
        # A non-positive version leaves the version field unset.
        if self.version > 0:
            request.model_spec.version.value = self.version

        input_slot = request.inputs['inputs']
        input_slot.CopyFrom(self.tensor_proto_from_measurement(measurement))

        self.logger.info("Attempting to predict against TF Serving API.")
        try:
            reply = self.stub.Predict(request, timeout=timeout)
        except RpcError as err:
            self.logger.error(err)
            self.logger.error('Predict failed.')
            return None
        return reply
以下是一个有效的(粗略的)C++ 翻译:
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <grpcpp/grpcpp.h>
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "google/protobuf/map.h"
#include "types.grpc.pb.h"
#include "tensor.grpc.pb.h"
#include "tensor_shape.grpc.pb.h"
#include "predict.grpc.pb.h"
#include "prediction_service.grpc.pb.h"
using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;
using tensorflow::serving::PredictionService;
typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap;
class ServingClient {
public:
ServingClient(std::shared_ptr<Channel> channel)
: stub_(PredictionService::NewStub(channel)) {}
// Assembles the client's payload, sends it and presents the response back
// from the server.
std::string callPredict(const std::string& model_name,
const float& measurement) {
// Data we are sending to the server.
PredictRequest request;
request.mutable_model_spec()->set_name(model_name);
// Container for the data we expect from the server.
PredictResponse response;
// Context for the client. It could be used to convey extra information to
// the server and/or tweak certain RPC behaviors.
ClientContext context;
google::protobuf::Map<std::string, tensorflow::TensorProto>& inputs =
*request.mutable_inputs();
tensorflow::TensorProto proto;
proto.set_dtype(tensorflow::DataType::DT_FLOAT);
proto.add_float_val(measurement);
proto.mutable_tensor_shape()->add_dim()->set_size(5);
proto.mutable_tensor_shape()->add_dim()->set_size(8);
proto.mutable_tensor_shape()->add_dim()->set_size(105);
inputs["inputs"] = proto;
// The actual RPC.
Status status = stub_->Predict(&context, request, &response);
// Act upon its status.
if (status.ok()) {
std::cout << "call predict ok" << std::endl;
std::cout << "outputs size is " << response.outputs_size() << std::endl;
OutMap& map_outputs = *response.mutable_outputs();
OutMap::iterator iter;
int output_index = 0;
for (iter = map_outputs.begin(); iter != map_outputs.end(); ++iter) {
tensorflow::TensorProto& result_tensor_proto = iter->second;
std::string section = iter->first;
std::cout << std::endl << section << ":" << std::endl;
if ("classes" == section) {
int titer;
for (titer = 0; titer != result_tensor_proto.int64_val_size(); ++titer) {
std::cout << result_tensor_proto.int64_val(titer) << ", ";
}
} else if ("scores" == section) {
int titer;
for (titer = 0; titer != result_tensor_proto.float_val_size(); ++titer) {
std::cout << result_tensor_proto.float_val(titer) << ", ";
}
}
std::cout << std::endl;
++output_index;
}
return "Done.";
} else {
std::cout << "gRPC call return code: " << status.error_code() << ": "
<< status.error_message() << std::endl;
return "RPC failed";
}
}
private:
std::unique_ptr<PredictionService::Stub> stub_;
};
请注意,此处的尺寸已在代码中指定,而不是传入。
给定上面的类,执行如下:
int main(int argc, char** argv) {
float measurement[5*8*105] = { ... data ... };
ServingClient sclient(grpc::CreateChannel(
"localhost:8500", grpc::InsecureChannelCredentials()));
std::string model("predict");
std::string reply = sclient.callPredict(model, *measurement);
std::cout << "Predict received: " << reply << std::endl;
return 0;
}
所用的 Makefile 借用自 gRPC C++ 示例,其中 PROTOS_PATH 变量相对于该 Makefile 设置,
并添加了以下构建目标(假设 C++ 应用程序名为 predict.cc):
# Link the predict client from the generated protobuf/gRPC objects and the
# application object. The recipe line must start with a TAB character.
predict: types.pb.o types.grpc.pb.o \
         tensor_shape.pb.o tensor_shape.grpc.pb.o \
         resource_handle.pb.o resource_handle.grpc.pb.o \
         model.pb.o model.grpc.pb.o \
         tensor.pb.o tensor.grpc.pb.o \
         predict.pb.o predict.grpc.pb.o \
         prediction_service.pb.o prediction_service.grpc.pb.o \
         predict.o
	$(CXX) $^ $(LDFLAGS) -o $@
推荐阅读
- google-cloud-platform - 使用服务帐户登录谷歌计算引擎
- exception - 是否可能出现“Try / except ValueError UNLESS”?
- ubuntu-16.04 - 如何恢复加载的 Kate 插件
- android - 错误的 Android Studio 所有属性面板
- apache-spark - Spark 超时时的 Hive
- python - 英语缩略正则表达式的解决方案不起作用
- corda - 无法将 cordapp-example 导入 IntelliJ 181.4892.42
- c - C 使用 execvp 执行带有另一个命令输出的命令
- excel - 复制所有行,粘贴到另一个工作表交替行与目标工作表上的现有数据)
- google-chrome-extension - 如何仅通过单击图标或按钮来关闭 chrome 扩展