tensorflow - How to set the signature name to serving_default to avoid an error after deploying to GCP AI Platform
Problem description
I trained a DNN model with TensorFlow on AI Platform. I then copied the model locally to double-check that I could get predictions from the same model:
gcloud ai-platform local predict --model-dir=/home/jupyter/end-to-end-ml/examples/e2e-ml-model-ex02/app/appbabyweight_trained/export/exporter/1615197796 --json-instances=inputs.json
This returns predictions, but with the following warning:
If the signature defined in the model is not `serving_default` then you must specify it via --signature-name flag, otherwise the command may fail.
(The warning goes away when the signature name is specified with --signature-name predict :)
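The local command reads its inputs from inputs.json, and --json-instances expects newline-delimited JSON, one instance per line. A minimal sketch of producing such a file for the four inputs accepted by serving_input_fn, assuming each instance is a flat object of scalars (every placeholder has shape [None], i.e. one value per instance); the row values and the helper name write_instances are hypothetical:

```python
import json

# Hypothetical example rows; the keys match the placeholders in serving_input_fn.
# Each value is a scalar because every placeholder has shape [None].
EXAMPLE_INSTANCES = [
    {"is_male": "True", "mother_age": 26.0, "plurality": "Single(1)", "gestation_weeks": 39.0},
    {"is_male": "False", "mother_age": 31.0, "plurality": "Twins(2)", "gestation_weeks": 37.0},
]


def write_instances(path, instances):
    """Write one JSON object per line (newline-delimited JSON) and return the path."""
    with open(path, "w", encoding="utf-8") as f:
        for instance in instances:
            f.write(json.dumps(instance) + "\n")
    return path
```

Calling write_instances("inputs.json", EXAMPLE_INSTANCES) produces a file that can be passed to --json-instances as in the command above.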
After deploying the model to AI Platform, the warning becomes an error: the serving signature name must be serving_default, as the following error message shows:
{"error": "Serving signature name: \"serving_default\" not found in signature def"}
I inspected the saved model with this command:
saved_model_cli show --dir /home/jupyter/end-to-end-ml/examples/e2e-ml-model-ex02/app/appbabyweight_trained2/output-dir/export/exporter/1615439076 --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['gestation_weeks'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: Placeholder_3:0
    inputs['is_male'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: Placeholder:0
    inputs['mother_age'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: Placeholder_1:0
    inputs['plurality'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: Placeholder_2:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['predictions'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: dnn/logits/BiasAdd:0
  Method name is: tensorflow/serving/predict
So the signature name of my saved model is predict, not serving_default.
My question is: how do I change the signature name?
PS: Here is how I defined the DNN:
import tensorflow as tf

# Define feature columns
def get_categorical(name, values):
    return tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_list(name, values))

def get_cols():
    # Define column types
    return [
        get_categorical('is_male', ['True', 'False', 'Unknown']),
        tf.feature_column.numeric_column('mother_age'),
        get_categorical('plurality',
                        ['Single(1)', 'Twins(2)', 'Triplets(3)',
                         'Quadruplets(4)', 'Quintuplets(5)', 'Multiple(2+)']),
        tf.feature_column.numeric_column('gestation_weeks')
    ]

# Create serving input function to be able to serve predictions later using provided inputs
def serving_input_fn():
    feature_placeholders = {
        'is_male': tf.compat.v1.placeholder(tf.string, [None]),
        'mother_age': tf.compat.v1.placeholder(tf.float32, [None]),
        'plurality': tf.compat.v1.placeholder(tf.string, [None]),
        'gestation_weeks': tf.compat.v1.placeholder(tf.float32, [None])
    }
    # Add a trailing dimension so each feature has shape (batch, 1)
    features = {
        key: tf.expand_dims(tensor, -1) for key, tensor in feature_placeholders.items()
    }
    return tf.estimator.export.ServingInputReceiver(features, feature_placeholders)

# Create estimator to train and evaluate
# (read_dataset and TRAIN_STEPS are defined elsewhere in the training script, not shown here)
def train_and_evaluate(args):
    EVAL_INTERVAL = 30
    run_config = tf.estimator.RunConfig(save_checkpoints_secs=EVAL_INTERVAL, keep_checkpoint_max=3)
    estimator = tf.estimator.DNNRegressor(
        model_dir=args['output_dir'],
        feature_columns=get_cols(),
        hidden_units=args['nnsize'],
        config=run_config)
    train_spec = tf.estimator.TrainSpec(
        input_fn=read_dataset(args['train_data_path'],
                              mode=tf.estimator.ModeKeys.TRAIN,
                              batch_size=args['batch_size']),
        max_steps=TRAIN_STEPS)
    exporter = tf.estimator.LatestExporter('exporter', serving_input_fn)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=read_dataset(args['eval_data_path'], mode=tf.estimator.ModeKeys.EVAL, batch_size=args['batch_size']),
        steps=args['eval_steps'],
        start_delay_secs=60,  # start evaluating after N seconds
        throttle_secs=EVAL_INTERVAL,  # evaluate every N seconds
        exporters=exporter)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Thanks
Solution
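Add a serving_default signature to an existing SavedModel: load the model, build a concrete function with the desired input spec, and save the model again with that function registered under serving_default.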
import tensorflow as tf
m = tf.saved_model.load("tf2-preview_inception_v3_classification_4")
print(m.signatures) # _SignatureMap({}) - Empty
t_spec = tf.TensorSpec([None,None,None,3], tf.float32)
c_func = m.__call__.get_concrete_function(inputs=t_spec)
signatures = {'serving_default': c_func}
tf.saved_model.save(m, 'tf2-preview_inception_v3_classification_5', signatures=signatures)
# Test new model
m5 = tf.saved_model.load("tf2-preview_inception_v3_classification_5")
print(m5.signatures) # _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, inputs) at 0x17316DC50>})
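The snippet above targets a model with no signatures at all. The model exported in the question already has a predict signature, so another option is to reuse that function under the serving_default key when re-saving. A minimal sketch, assuming TensorFlow 2.x; rename_signature, TinyModel, build_demo_model and the temporary directories are hypothetical, and the tiny model exists only so the example runs end to end (with the real model you would pass the exporter directory as src_dir). Whether this works unchanged for a TF1 Estimator export depends on TF 2.x being able to load and re-save that v1 SavedModel, which is an assumption here:

```python
import os
import tempfile

import tensorflow as tf


def rename_signature(src_dir, dst_dir, old_name="predict", new_name="serving_default"):
    """Re-save the SavedModel at src_dir so its `old_name` signature is exported as `new_name`."""
    loaded = tf.saved_model.load(src_dir)
    signature_fn = loaded.signatures[old_name]  # KeyError if the signature does not exist
    tf.saved_model.save(loaded, dst_dir, signatures={new_name: signature_fn})
    return dst_dir


class TinyModel(tf.Module):
    """Hypothetical stand-in for the exported DNN: one float input, one 'predictions' output."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32, name="mother_age")])
    def predict(self, mother_age):
        # Output shape (-1, 1), mirroring outputs['predictions'] in the saved_model_cli listing.
        return {"predictions": tf.expand_dims(mother_age * self.scale, -1)}


def build_demo_model(export_dir):
    """Save TinyModel with only a 'predict' signature, like the Estimator export in the question."""
    model = TinyModel()
    tf.saved_model.save(model, export_dir, signatures={"predict": model.predict})
    return export_dir


workdir = tempfile.mkdtemp()
src = build_demo_model(os.path.join(workdir, "original"))
dst = rename_signature(src, os.path.join(workdir, "renamed"))
print(list(tf.saved_model.load(dst).signatures.keys()))
```

Re-running saved_model_cli show --all on the new directory should then list signature_def['serving_default'] with the same inputs and outputs as the old predict signature.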