TF2 SavedModel Pruning and Freezing

Problem Description

Using TF 2.2.rc3, I have a SavedModel object generated from an Estimator via the following serving input function:

def serving_fn():
    inputs = {}  # shown empty here; presumably built from the model's input placeholders
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)

Then I export it with export_path = model.export_saved_model(export_dir, serving_fn).
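
For reference, a receiver that would actually produce the 'feature' signature reported by saved_model_cli further down might look like the sketch below. Only the input name, dtype, and shape come from the saved_model_cli output; the placeholder construction itself is an assumption.

import tensorflow as tf

def serving_fn():
    # Hypothetical expansion of the empty dict above: one int32 feature of
    # shape [None], matching inputs['feature'] (DT_INT32, shape (-1),
    # name Placeholder:0) in the signature shown below.
    feature = tf.compat.v1.placeholder(tf.int32, shape=[None])
    inputs = {"feature": feature}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)

# export_path = model.export_saved_model(export_dir, serving_fn)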

I then want to optimize this model, so I do the following (based on this answer):

imported = tf.saved_model.load(export_dir)
pruned = imported.prune(input_node, output_node)
# input_node / output_node are graph tensor names, e.g. "Placeholder:0" and
# "BiasAdd:0" from the serving signature shown below

from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
frozen_func = convert_variables_to_constants_v2(pruned)

class Exportable(tf.Module):
    @tf.function
    def __call__(self, model_inputs):
        # the second input is to satisfy the global_step tensor from the estimator input
        return frozen_func(model_inputs, tf.ones([], dtype=tf.dtypes.float32))

    def call(self, model_inputs):
        # created this to attempt to fix the error
        return frozen_func(model_inputs, tf.ones([], dtype=tf.dtypes.float32))

svmod2_export = Exportable()

svmod2_export(tf.ones(dummy_input_shape, dtype=tf.as_dtype(dummy_input_dtype)))
tf.saved_model.save(svmod2_export, 'frozen_savedmodel/')
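
As an aside (not something the question does): because __call__ above is only ever traced with the shape-(1,) dummy input, the saved function ends up fixed to that shape, which is exactly what the ValueError at the end complains about. A sketch of a variant that pins an explicit input_signature and also attaches a named serving signature, with illustrative names and assuming frozen_func as built above:

class ExportableV2(tf.Module):
    # Illustrative variant: an explicit input_signature keeps the first
    # dimension flexible instead of freezing it to the dummy input's shape.
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None], dtype=tf.int32, name="model_inputs")])
    def __call__(self, model_inputs):
        # second argument satisfies the estimator's global_step input, as above
        return frozen_func(model_inputs, tf.ones([], dtype=tf.float32))

svmod2_v2 = ExportableV2()
tf.saved_model.save(
    svmod2_v2, 'frozen_savedmodel_v2/',
    signatures={'serving_default': svmod2_v2.__call__.get_concrete_function()})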

However, when I try:

from tensorflow.python.keras.saving import saving_utils as _saving_utils
model = tf.keras.models.load_model(filename)
tf.keras.backend.set_learning_phase(False)
func = _saving_utils.trace_model_call(model)

I get the error:

  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py", line 113, in trace_model_call
    if isinstance(model.call, def_function.Function):
AttributeError: '_UserObject' object has no attribute 'call'
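
For what it's worth, the traceback itself points at the cause: trace_model_call() expects a Keras model that exposes a .call attribute, but a SavedModel that was not written by Keras is restored as a generic '_UserObject'. A generic load of the export above (path illustrative) only exposes the traced functions, not .call:

loaded = tf.saved_model.load('frozen_savedmodel/')  # generic restore, not a tf.keras.Model
print(hasattr(loaded, 'call'))   # False for this export, hence the AttributeError
infer = loaded.__call__          # the @tf.function traced above is what gets restored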

Inspecting the original SavedModel with saved_model_cli:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['feature'] tensor_info:
        dtype: DT_INT32
        shape: (-1)
        name: Placeholder:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 99)
        name: BiasAdd:0
  Method name is: tensorflow/serving/predict

When I inspect the new SavedModel with saved_model_cli:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

Defined Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          model_inputs: TensorSpec(shape=(1,), dtype=tf.int32, name='model_inputs')

  Function Name: 'call'
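
Reading that output: the new SavedModel has no serving signature (only the auto-generated init op), the traced __call__ is fixed to shape (1,), and 'call' has no traced options. One way to confirm the missing signature from Python (path illustrative; presumably empty here because no signatures= argument was passed to tf.saved_model.save):

loaded = tf.saved_model.load('frozen_savedmodel/')
print(list(loaded.signatures.keys()))  # expected to be [] for this export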

Loading the model and trying to run a value through it gives the following error (although the key problem I want to solve is the trace_model_call() one):

>>> m = tf.saved_model.load('.')
2020-04-28 18:47:39.012133: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-04-28 18:47:39.032677: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f9a59d20d40 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-04-28 18:47:39.032701: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
>>> m(tf.constant(1))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 486, in _call_attribute
    return instance.__call__(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 506, in _initialize
    *args, **kwds))
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2667, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 261, in restored_function_body
    "\n\n".join(signature_descriptions)))
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (1 total):
    * Tensor("model_inputs:0", shape=(), dtype=int32)
  Keyword arguments: {}

Expected these arguments to match one of the following 1 option(s):

Option 1:
  Positional arguments (1 total):
    * TensorSpec(shape=(1,), dtype=tf.int32, name='model_inputs')
  Keyword arguments: {}
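
As the message spells out, the only concrete function that was saved is the one traced from the shape-(1,) dummy input, so the restored __call__ accepts exactly TensorSpec(shape=(1,), dtype=tf.int32); tf.constant(1) is a scalar of shape (), which is why it is rejected. Purely to illustrate the shape mismatch (not a fix for the trace_model_call() problem), a call that matches the saved option would be:

m = tf.saved_model.load('.')
m(tf.constant([1], dtype=tf.int32))   # shape (1,), matches the saved TensorSpec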

Tags: python, tensorflow, keras

Solution

