tensorflow - Operation ParseExample not supported while converting SavedModel to TFLite
Problem description
I am using a TensorFlow Estimator to train and save a model and then convert it to .tflite. I saved the model as follows:
import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.parse_example)

feat_cols = [tf.feature_column.numeric_column('feature1'),
             tf.feature_column.numeric_column('feature2'),
             tf.feature_column.numeric_column('feature3'),
             tf.feature_column.numeric_column('feature4')]

def serving_input_receiver_fn():
    """An input receiver that expects a serialized tf.Example."""
    feature_spec = tf.feature_column.make_parse_example_spec(feat_cols)
    default_batch_size = 1
    serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='tf_example')
    receiver_tensors = {'examples': serialized_tf_example}
    # Parsing the serialized protos adds a ParseExample op to the exported graph.
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

# dnn_regressor is the trained tf.estimator.DNNRegressor
dnn_regressor.export_saved_model(export_dir_base='model',
                                 serving_input_receiver_fn=serving_input_receiver_fn)
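The receiver above expects serialized `tf.Example` protos under the `examples` key, so the exported graph has to decode them itself, and that decoding step is the `ParseExample` op the converter rejects. As a minimal sketch of what a client would have to feed that receiver, the code below uses the TF 2.x API (`tf.train.Example`, `tf.io.parse_example`) rather than the TF 1.x calls above; the hand-written `feature_spec` is an assumption meant to mirror what `make_parse_example_spec` returns for four scalar `numeric_column`s.

```python
import tensorflow as tf

FEATURES = ["feature1", "feature2", "feature3", "feature4"]

# Assumed equivalent of make_parse_example_spec(feat_cols) for scalar numeric columns.
feature_spec = {name: tf.io.FixedLenFeature([1], tf.float32) for name in FEATURES}


def make_serialized_example(values):
    """Build one serialized tf.train.Example from a {feature_name: float} dict."""
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                name: tf.train.Feature(float_list=tf.train.FloatList(value=[values[name]]))
                for name in FEATURES
            }
        )
    )
    return example.SerializeToString()


# A batch of one serialized proto, matching default_batch_size = 1 in the receiver.
serialized = tf.constant([make_serialized_example(
    {"feature1": 1.0, "feature2": 2.0, "feature3": 3.0, "feature4": 4.0})])

# The same kind of parsing step the TF 1.x receiver bakes into the exported graph.
parsed = tf.io.parse_example(serialized, feature_spec)
print({name: parsed[name].numpy().tolist() for name in FEATURES})
```

Each parsed feature comes back as a dense float tensor of shape `[batch, 1]`, which is what the model's feature columns consume.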
When I try to convert the resulting .pb file using:
tflite_convert --output_file=/tmp/foo.tflite --saved_model_dir=/tmp/saved_model
I get an exception saying that the ParseExample operation is not supported by TensorFlow Lite:
Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If those are native TensorFlow operators, you might be able to use the extended runtime by passing --enable_select_tf_ops, or by setting target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling tf.lite.TFLiteConverter(). Otherwise, if you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a list of builtin operators you are using: CONCATENATION, FULLY_CONNECTED, RESHAPE. Here is a list of operators for which you will need custom implementations: ParseExample.
If I instead export the model without serialization, then when I try to predict with the resulting .pb file, the predict function expects an empty set() of inputs rather than the dict of inputs I am passing:
ValueError: Got unexpected keys in input_dict: {'feature1', 'feature2', 'feature3', 'feature4'} expected: set()
What am I doing wrong? Here is the code that attempts to save the model without any serialization:
features = {
    'feature1': tf.placeholder(dtype=tf.float32, shape=[1], name='feature1'),
    'feature2': tf.placeholder(dtype=tf.float32, shape=[1], name='feature2'),
    'feature3': tf.placeholder(dtype=tf.float32, shape=[1], name='feature3'),
    'feature4': tf.placeholder(dtype=tf.float32, shape=[1], name='feature4')
}

def serving_input_receiver_fn():
    return tf.estimator.export.ServingInputReceiver(features, features)

dnn_regressor.export_savedmodel(export_dir_base='model',
                                serving_input_receiver_fn=serving_input_receiver_fn,
                                as_text=True)
Solution
Solved.
Using build_raw_serving_input_receiver_fn, I managed to export the SavedModel without any serialization:
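# `features` is the dict of float32 placeholders from the question; the helper
# builds matching placeholders (same dtypes) inside the export graph.
serve_input_fun = tf.estimator.export.build_raw_serving_input_receiver_fn(
    features,
    default_batch_size=None  # leave the batch dimension unspecified
)
dnn_regressor.export_savedmodel(
    export_dir_base="model",
    serving_input_receiver_fn=serve_input_fun,
    as_text=True
)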
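Note: when making predictions, the Predictor does not know the default signature_def, so I need to specify it:
from tensorflow.contrib import predictor  # TF 1.x
predict_fn = predictor.from_saved_model("model/155482...", signature_def_key="predict")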
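For comparison, a minimal TF 2.x sketch of the same pattern: a SavedModel whose only signature, `"predict"`, takes one raw float32 tensor per feature, loaded and called by that key. `ToyRegressor` and its fixed weights are hypothetical stand-ins for the trained `dnn_regressor`, and `tf.saved_model.load(...).signatures[...]` plays the role that `predictor.from_saved_model(..., signature_def_key=...)` plays in TF 1.x.

```python
import tempfile

import tensorflow as tf


class ToyRegressor(tf.Module):
    """Hypothetical stand-in for the trained dnn_regressor: a fixed linear model."""

    def __init__(self):
        super().__init__()
        # Fixed weights, chosen only so the output is easy to check.
        self.w = tf.Variable([[1.0], [2.0], [3.0], [4.0]])

    # One raw float32 input per feature (any batch size): the graph receives the
    # values directly, so there is nothing to parse and no ParseExample op.
    @tf.function(input_signature=[
        tf.TensorSpec([None], tf.float32, name="feature1"),
        tf.TensorSpec([None], tf.float32, name="feature2"),
        tf.TensorSpec([None], tf.float32, name="feature3"),
        tf.TensorSpec([None], tf.float32, name="feature4"),
    ])
    def predict(self, feature1, feature2, feature3, feature4):
        x = tf.stack([feature1, feature2, feature3, feature4], axis=1)  # [batch, 4]
        return {"predictions": tf.matmul(x, self.w)}  # [batch, 1]


model = ToyRegressor()
export_dir = tempfile.mkdtemp()
# Export under the key "predict" only; no "serving_default" signature is created.
tf.saved_model.save(model, export_dir, signatures={"predict": model.predict})

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))

# Select the signature explicitly by key.
predict_fn = loaded.signatures["predict"]
# Signature functions take keyword arguments named after their inputs.
out = predict_fn(
    feature1=tf.constant([1.0]),
    feature2=tf.constant([2.0]),
    feature3=tf.constant([3.0]),
    feature4=tf.constant([4.0]),
)
print(out["predictions"].numpy())
```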
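Also, to convert from .pb to .tflite I used the Python API, because I also need to specify the signature_def there:
converter = tf.contrib.lite.TFLiteConverter.from_saved_model('model/155482....', signature_key='predict')
tflite_model = converter.convert()  # serialized TFLite flatbuffer (bytes)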
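The conversion step can be sketched end to end in TF 2.x, under the same assumptions as before: `ToyRegressor`, its weights, and the fixed batch of 1 are hypothetical, and a real Estimator export would still go through the TF 1.x call above. In TF 2.x, `tf.lite.TFLiteConverter.from_saved_model` takes `signature_keys=[...]` instead of `signature_key=`, and `tf.lite.Interpreter.get_signature_runner` runs the converted model through that signature with plain float arrays.

```python
import tempfile

import numpy as np
import tensorflow as tf


class ToyRegressor(tf.Module):
    """Hypothetical stand-in for the trained dnn_regressor: a fixed linear model."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[1.0], [2.0], [3.0], [4.0]])

    # Raw float32 inputs with a fixed batch of 1, so no input resizing is needed.
    @tf.function(input_signature=[
        tf.TensorSpec([1], tf.float32, name="feature1"),
        tf.TensorSpec([1], tf.float32, name="feature2"),
        tf.TensorSpec([1], tf.float32, name="feature3"),
        tf.TensorSpec([1], tf.float32, name="feature4"),
    ])
    def predict(self, feature1, feature2, feature3, feature4):
        x = tf.stack([feature1, feature2, feature3, feature4], axis=1)  # [1, 4]
        return {"predictions": tf.matmul(x, self.w)}  # [1, 1]


model = ToyRegressor()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir, signatures={"predict": model.predict})

# Convert only the "predict" signature; variables are frozen into constants.
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir, signature_keys=["predict"])
tflite_model = converter.convert()

# Run the converted model by signature key with plain float arrays.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())
runner = interpreter.get_signature_runner("predict")
out = runner(
    feature1=np.array([1.0], dtype=np.float32),
    feature2=np.array([2.0], dtype=np.float32),
    feature3=np.array([3.0], dtype=np.float32),
    feature4=np.array([4.0], dtype=np.float32),
)
print(out["predictions"])
```

Because the inputs are raw floats, the graph contains only builtin ops, so no `--enable_select_tf_ops` or `--allow_custom_ops` is needed.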