首页 > 解决方案 > Keras 在加载 lambda 层时引发异常

问题描述

在 TensorFlow 2.3.1 中保存和加载模型...

import tensorflow as tf

# Minimal reproduction: a Sequential model whose only layer is a Lambda that
# wraps the backend function `one_hot`, with `num_classes` passed via `arguments`.
model1 = tf.keras.Sequential([
    tf.keras.layers.Input(shape = (81,), dtype = 'uint8'), 
    tf.keras.layers.Lambda(tf.keras.backend.one_hot, arguments={'num_classes': 10}, output_shape=(81, 10)),
])

# Round-trip the model through disk. According to the traceback quoted below,
# the TypeError is raised inside `load_model` while the Sequential model is
# being reconstructed and the Lambda layer is called.
tf.keras.models.save_model(model1, './model')
model2 = tf.keras.models.load_model('./model')
model2.summary()

会引发一个相当长的异常(全文引用如下)。问题似乎出在 Lambda 层上:去掉它之后,保存和加载都能正常工作。我已经尝试在 load 调用中添加 custom_objects={'one_hot' : tf.keras.backend.one_hot} 等参数,但并没有解决问题。如果有人知道解决方法,我将不胜感激。

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/backend.py in wrapper(*args, **kwargs)
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):

TypeError: 'str' object is not callable

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-20-77668ee9d9b2> in <module>
      7 
      8 tf.keras.models.save_model(model1, './model')
----> 9 model2 = tf.keras.models.load_model('./model', custom_objects={'tf.keras.backend.one_hot' : tf.keras.backend.one_hot})
     10 model2.summary()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
    185     if isinstance(filepath, six.string_types):
    186       loader_impl.parse_saved_model(filepath)
--> 187       return saved_model_load.load(filepath, compile, options)
    188 
    189   raise IOError(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in load(path, compile, options)
    119 
    120   model = tf_load.load_internal(
--> 121       path, options=options, loader_cls=KerasObjectLoader)
    122 
    123   # pylint: disable=protected-access

/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in load_internal(export_dir, tags, options, loader_cls)
    631       try:
    632         loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
--> 633                             ckpt_options)
    634       except errors.NotFoundError as err:
    635         raise FileNotFoundError(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in __init__(self, *args, **kwargs)
    192     self._models_to_reconstruct = []
    193 
--> 194     super(KerasObjectLoader, self).__init__(*args, **kwargs)
    195 
    196     # Now that the node object has been fully loaded, and the checkpoint has

/opt/conda/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in __init__(self, object_graph_proto, saved_model_proto, export_dir, ckpt_options)
    128       self._concrete_functions[name] = _WrapperFunction(concrete_function)
    129 
--> 130     self._load_all()
    131     self._restore_checkpoint()
    132 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_all(self)
    219 
    220     # Finish setting up layers and models. See function docstring for more info.
--> 221     self._finalize_objects()
    222 
    223   @property

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _finalize_objects(self)
    528 
    529     # Initialize graph networks, now that layer dependencies have been resolved.
--> 530     self._reconstruct_all_models()
    531 
    532   def _unblock_model_reconstruction(self, layer_id, layer):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _reconstruct_all_models(self)
    546       all_initialized_models.add(model_id)
    547       model, layers = self.model_layer_dependencies[model_id]
--> 548       self._reconstruct_model(model_id, model, layers)
    549       self._add_object_graph_edges(self._proto.nodes[model_id], model_id)
    550       _finalize_config_layers([model])

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _reconstruct_model(self, model_id, model, layers)
    576               dtype=layers[0].dtype,
    577               name=layers[0].name + '_input'))
--> 578       model.__init__(layers, name=config['name'])
    579       if not model.inputs:
    580         first_layer = self._get_child_layer_node_ids(model_id, model.name)[0]

/opt/conda/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/sequential.py in __init__(self, layers, name)
    140         layers = [layers]
    141       for layer in layers:
--> 142         self.add(layer)
    143 
    144   @property

/opt/conda/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/sequential.py in add(self, layer)
    219       # If the model is being built continuously on top of an input layer:
    220       # refresh its output.
--> 221       output_tensor = layer(self.outputs[0])
    222       if len(nest.flatten(output_tensor)) != 1:
    223         raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    924     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
    925       return self._functional_construction_call(inputs, args, kwargs,
--> 926                                                 input_list)
    927 
    928     # Maintains info about the `Layer.call` stack.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1115           try:
   1116             with ops.enable_auto_cast_variables(self._compute_dtype_object):
-> 1117               outputs = call_fn(cast_inputs, *args, **kwargs)
   1118 
   1119           except errors.OperatorNotAllowedInGraphError as e:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py in call(self, inputs, mask, training)
    901     with backprop.GradientTape(watch_accessed_variables=True) as tape,\
    902         variable_scope.variable_creator_scope(_variable_creator):
--> 903       result = self.function(inputs, **kwargs)
    904     self._check_variables(created_variables, tape.watched_variables())
    905     return result

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/backend.py in wrapper(*args, **kwargs)
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a
    204       # TypeError, when given unexpected types.  So we need to catch both.
--> 205       result = dispatch(wrapper, args, kwargs)
    206       if result is not OpDispatcher.NOT_SUPPORTED:
    207         return result

TypeError: 'module' object is not callable

标签: python tensorflow exception keras

解决方案


你也许可以改用 tf.keras.layers.experimental.preprocessing.CategoryEncoding 层。注意:它输出的是每个类别的计数(或二值)向量,形状为 (batch, 10),而不是原 Lambda 层那样逐元素的 one-hot 编码 (batch, 81, 10)。

import tensorflow as tf

# Encode the 81 integer inputs into a 10-slot vector with a built-in
# preprocessing layer instead of a Lambda layer.
preprocessing = tf.keras.layers.experimental.preprocessing
layer_stack = [
    tf.keras.layers.Input(shape=(81,), dtype=tf.int32),
    preprocessing.CategoryEncoding(max_tokens=10, output_mode='count'),  # or 'binary'
]
model1 = tf.keras.Sequential(layer_stack)

# Save the model to disk and load it back to confirm the round trip works.
saved_path = './model'
tf.keras.models.save_model(model1, saved_path)
model2 = tf.keras.models.load_model(saved_path)
model2.summary()

# One batch of 81 random integer class ids, sampled from minval=0 up to maxval=10.
inp = tf.random.uniform(shape=(1, 81), minval=0, maxval=10, dtype=tf.int32)

# In a notebook/REPL this bare expression echoes the reloaded model's output.
model2(inp)
<tf.Tensor: shape=(1, 10), dtype=float32, 
    numpy=array([[ 7., 11.,  5.,  6., 10.,  5.,  5., 12., 10., 10.]], dtype=float32)>

推荐阅读