首页 > 解决方案 > 使用 Rest API 查询 tensorflow_model_server 文本分类器返回 Invalid Argument

问题描述

我想运行一个tensorflow可以用字符串数组查询的服务器,并让它返回每个字符串的分类。

我的估算器如下所示:

def _initialise_estimator():

   feature_columns = _create_feature_columns()

   # Set run configs. Only keep one checkpoint.
   config = tensorflow.estimator.RunConfig(
      keep_checkpoint_max = 1,
   )

   # Create the estimator.
   estimator = tensorflow.estimator.DNNClassifier(
      hidden_units = [500, 100],
      feature_columns = feature_columns,
      n_classes = 98,
      optimizer = tensorflow.train.AdagradOptimizer(learning_rate = 0.003),
      model_dir = _model_location,
      config = config
   )

   return estimator


def _create_feature_columns():
   # Feature column. Assumed to be called description.
   return [tensorflow_hub.text_embedding_column(
      key = 'description',
      module_spec = 'https://tfhub.dev/google/nnlm-en-dim128/1'
   )]

_estimator = _initialise_estimator()

我可以成功训练模型:

data = pandas.read_csv(_pathToData).dropna()
dataset = data[['description', 'class']].copy()

training_input_function = tensorflow.estimator.inputs.pandas_input_fn(
   dataset,
   dataset['class'],
   num_epochs = None,
   shuffle = True
)

_estimator.train(
   input_fn = training_input_function,
   steps = 1000
)

用它来预测事物:

descriptions = pandas.DataFrame(['sales', 'equipment'], columns=['description'])

predict_input_fn = tensorflow.estimator.inputs.pandas_input_fn(
   descriptions,
   shuffle = False
)

_estimator.predict(input_fn = predict_input_fn)

并使用以下方法导出模型export_savedmodel

feature_columns = _create_feature_columns()
feature_spec = tensorflow.feature_column.make_parse_example_spec(feature_columns)

serving_input_receiver_fn = tensorflow.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

_estimator.export_savedmodel(
   export_dir_base = _export_location,
   serving_input_receiver_fn = serving_input_receiver_fn,
   strip_default_attrs = True,
   as_text = True
)

服务器使用以下命令运行:

tensorflow_model_server --rest_api_port=8501 --model_name=prediction --model_base_path=/app/prediction/lib/export

成功启动服务器。

但是,当我尝试使用以下命令查询服务器时:

curl -d '{"instances": ["sales"]}' -X POST http://localhost:8501/v1/models/prediction:predict

我收到以下错误:

{ "error": "无法解析示例输入,值:\'sales\'\n\t [[{{node ParseExample/ParseExample}} = ParseExample[Ndense=1, Nsparse=0, Tdense=[DT_STRING], _output_shapes=[[?,1]], dense_shapes=[[1]], sparse_types=[], _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](_arg_input_example_tensor_0_0 , ParseExample/ParseExample/names, ParseExample/ParseExample/dense_keys_0, ParseExample/ParseExample/names)]]" }

我究竟做错了什么?有谁知道实例数组应该采用什么格式?我这辈子都无法让它发挥作用。我怀疑这与我如何保存模型有关,export_savedmodel但我不确定。

非常感谢任何帮助。

标签: python-3.xtensorflowtensorflow-servingtensorflow-estimator

解决方案


推荐阅读