python-3.x - 使用 Rest API 查询 tensorflow_model_server 文本分类器返回 Invalid Argument
问题描述
我想运行一个tensorflow
可以用字符串数组查询的服务器,并让它返回每个字符串的分类。
我的估算器如下所示:
def _initialise_estimator():
feature_columns = _create_feature_columns()
# Set run configs. Only keep one checkpoint.
config = tensorflow.estimator.RunConfig(
keep_checkpoint_max = 1,
)
# Create the estimator.
estimator = tensorflow.estimator.DNNClassifier(
hidden_units = [500, 100],
feature_columns = feature_columns,
n_classes = 98,
optimizer = tensorflow.train.AdagradOptimizer(learning_rate = 0.003),
model_dir = _model_location,
config = config
)
return estimator
def _create_feature_columns():
# Feature column. Assumed to be called description.
return [tensorflow_hub.text_embedding_column(
key = 'description',
module_spec = 'https://tfhub.dev/google/nnlm-en-dim128/1'
)]
_estimator = _initialise_estimator()
我可以成功训练模型:
data = pandas.read_csv(_pathToData).dropna()
dataset = data[['description', 'class']].copy()
training_input_function = tensorflow.estimator.inputs.pandas_input_fn(
dataset,
dataset['class'],
num_epochs = None,
shuffle = True
)
_estimator.train(
input_fn = training_input_function,
steps = 1000
)
用它来预测事物:
descriptions = pandas.DataFrame(['sales', 'equipment'], columns=['description'])
predict_input_fn = tensorflow.estimator.inputs.pandas_input_fn(
descriptions,
shuffle = False
)
_estimator.predict(input_fn = predict_input_fn)
并使用以下方法导出模型export_savedmodel
:
feature_columns = _create_feature_columns()
feature_spec = tensorflow.feature_column.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = tensorflow.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
_estimator.export_savedmodel(
export_dir_base = _export_location,
serving_input_receiver_fn = serving_input_receiver_fn,
strip_default_attrs = True,
as_text = True
)
服务器使用以下命令运行:
tensorflow_model_server --rest_api_port=8501 --model_name=prediction --model_base_path=/app/prediction/lib/export
成功启动服务器。
但是,当我尝试使用以下命令查询服务器时:
curl -d '{"instances": ["sales"]}' -X POST http://localhost:8501/v1/models/prediction:predict
我收到以下错误:
{ "error": "无法解析示例输入,值:\'sales\'\n\t [[{{node ParseExample/ParseExample}} = ParseExample[Ndense=1, Nsparse=0, Tdense=[DT_STRING], _output_shapes=[[?,1]], dense_shapes=[[1]], sparse_types=[], _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](_arg_input_example_tensor_0_0 , ParseExample/ParseExample/names, ParseExample/ParseExample/dense_keys_0, ParseExample/ParseExample/names)]]" }
我究竟做错了什么?有谁知道实例数组应该采用什么格式?我这辈子都无法让它发挥作用。我怀疑这与我如何保存模型有关,export_savedmodel
但我不确定。
非常感谢任何帮助。
解决方案
推荐阅读
- java - 如何在我从字符串中得到的数字后面加上零的“循环”模式?
- qt - 如何将一个插槽的变量声明到另一个插槽
- leaflet - 带有 shadow-cljs 的 react-leaflet 出现不清楚的错误
- reactjs - 如何为 MUI v4 表的备用行设置不同的颜色?
- c++ - 如何正确#include 一个柯南包?
- amazon-web-services - 如何生成指向具有无限 TTL 文件的公共 amazon-s3 链接?
- multithreading - 数据包序列号错误 - 得到 7 预期 2 celery+flask-sqlalchemy
- java - Xamarin 使用 CallRedirectionService 引发 Java.Lang.ClassNotFoundException
- html - VS Code 扩展以纠正 html 自动关闭标签烦人的行为
- corosync - Corosync 仅显示一个节点