首页 > 解决方案 > 没有参数的 TensorFlow Serving 导出签名

问题描述

我想为 SavadModel 添加额外的签名,它将返回业务描述并使用 TensorFlow Serving 提供服务。

@tf.function
def info():
    return json.dumps({
       'name':  'My model',
       'description': 'This is model description.',
       'project': 'Product ABCD',
       'type': 'some_type',
       ...
})

正如 TensorFlow Core 手册https://www.tensorflow.org/guide/saved_model#identifying_a_signature_to_export中所写,我可以轻松导出接受提供 tf.TensorSpec 的参数的签名。

是否可以在没有参数的情况下导出签名并在服务器上调用它?


在@EricMcLachlan 评论后添加:

当我尝试使用如下代码调用没有定义签名(input_signature=[])的函数时:

data = json.dumps({"signature_name": "info", "inputs": None})

headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/my_model:predict', data=data, headers=headers)

我在响应中得到下一个错误:

'_content': b'{ "error": "未能获取签名的输入映射:信息" }'

标签: tensorflowtensorflow-servingtfx

解决方案


定义签名:

我打算写我自己的例子,但这是@AntPhitlok 在另一个StackOverflow 帖子中提供的一个很好的例子:

class MyModule(tf.Module):
  def __init__(self, model, other_variable):
    self.model = model
    self._other_variable = other_variable

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)])
  def score(self, waveform):
    result = self.model(waveform)
    return { "scores": results }

  @tf.function(input_signature=[])
  def metadata(self):
    return { "other_variable": self._other_variable }

在这种情况下,他们服务的是一个模块,但它也可能是一个 Keras 模型。


使用服务:

我不是 100% 确定如何访问服务(我自己还没有完成),但我认为您将能够访问与此类似的服务:

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = model_name
request.model_spec.signature_name = 'serving_default'
request.model_spec.version_label = self.version

tensor_proto = tf.make_tensor_proto(my_input_data, dtype=tf.float32)
request.inputs['my_signatures_input'].CopyFrom(tensor_proto)

try:
    response = self.stub.Predict(request, MAX_TIMEOUT)
except Exception as ex:
    logging.error(str(ex))
    return [None] * len(batch_of_texts)

这里我使用 gRPC 来访问 TensorFlow Server。

您可能需要用您的服务名称替换“serving_default”。同样,“my_signature_input”应该与您的输入相匹配tf.function(在您的情况下,我认为它是空的)。

这是一个正常的标准 Keras 类型预测,是 predict_pb2.PredictRequest 的附带条件。可能有必要创建一个自定义 Protobuf,但这有点超出了我的能力。

我希望这足以让你继续前进。


推荐阅读