Operation ParseExample not supported while converting SavedModel to TFLite

Problem description

I am using a TensorFlow Estimator to train and save a model, and then convert it to .tflite. I saved the model as follows:

feat_cols = [tf.feature_column.numeric_column('feature1'),
             tf.feature_column.numeric_column('feature2'),
             tf.feature_column.numeric_column('feature3'),
             tf.feature_column.numeric_column('feature4')]

def serving_input_receiver_fn():
    """An input receiver that expects a serialized tf.Example."""
    feature_spec = tf.feature_column.make_parse_example_spec(feat_cols)
    default_batch_size = 1
    serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='tf_example')
    receiver_tensors = {'examples': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)


dnn_regressor.export_saved_model(export_dir_base='model',
                                 serving_input_receiver_fn=serving_input_receiver_fn)

When I try to convert the resulting .pb file using:

tflite_convert --output_file=/tmp/foo.tflite --saved_model_dir=/tmp/saved_model

I get an exception saying that the ParseExample operation is not supported by TensorFlow Lite:

Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If those are native TensorFlow operators, you might be able to use the extended runtime by passing --enable_select_tf_ops, or by setting target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling tf.lite.TFLiteConverter(). Otherwise, if you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a list of builtin operators you are using: CONCATENATION, FULLY_CONNECTED, RESHAPE. Here is a list of operators for which you will need custom implementations: ParseExample.
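
(For reference, the target_ops route that the error message suggests maps onto the Python API roughly as follows. This is an untested sketch; in newer TF versions the attribute is converter.target_spec.supported_ops rather than target_ops.)

import tensorflow as tf

# Untested sketch: let the converter keep otherwise-unsupported native TF ops
# (here ParseExample) by enabling the Select TF Ops extended runtime.
converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
converter.target_ops = set([tf.lite.OpsSet.TFLITE_BUILTINS,
                            tf.lite.OpsSet.SELECT_TF_OPS])
tflite_model = converter.convert()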

If I export the model without serialization, then when I try to predict on the resulting .pb file, the function expects an empty set() rather than the dict of inputs I am passing:

ValueError: Got unexpected keys in input_dict: {'feature1', 'feature2', 'feature3', 'feature4'} expected: set()

What am I doing wrong? Here is the code that attempts to save the model without doing any serialization:

features = {
    'feature1': tf.placeholder(dtype=tf.float32, shape=[1], name='feature1'),
    'feature2': tf.placeholder(dtype=tf.float32, shape=[1], name='feature2'),
    'feature3': tf.placeholder(dtype=tf.float32, shape=[1], name='feature3'),
    'feature4': tf.placeholder(dtype=tf.float32, shape=[1], name='feature4')
}

def serving_input_receiver_fn():
    return tf.estimator.export.ServingInputReceiver(features, features)


dnn_regressor.export_savedmodel(export_dir_base='model', serving_input_receiver_fn=serving_input_receiver_fn, as_text=True)

Tags: tensorflow, tensorflow-estimator, tensorflow-lite

Solution


Solved it!

Using build_raw_serving_input_receiver_fn, I managed to export the saved model without any serialization:

serve_input_fun = tf.estimator.export.build_raw_serving_input_receiver_fn(
    features,
    default_batch_size=None
)

dnn_regressor.export_savedmodel(
    export_dir_base="model",
    serving_input_receiver_fn=serve_input_fun,
    as_text=True
)

Note: when making predictions, the Predictor does not know the default signature_def, so I needed to specify it:

predict_fn = predictor.from_saved_model("model/155482...", signature_def_key="predict")
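
Calling the returned predict_fn takes a dict keyed by the raw feature names. The values below are dummy numbers, purely for illustration:

# Dummy feature values for illustration only.
predictions = predict_fn({
    'feature1': [1.0],
    'feature2': [2.0],
    'feature3': [3.0],
    'feature4': [4.0]
})
print(predictions)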

Also, to convert from .pb to .tflite I used the Python API, because I needed to specify the signature_def there as well:

converter = tf.contrib.lite.TFLiteConverter.from_saved_model('model/155482....', signature_key='predict')
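
From there, the conversion finishes with the standard calls:

tflite_model = converter.convert()
open('model.tflite', 'wb').write(tflite_model)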
