首页 > 解决方案 > tf.keras.estimator.model_to_estimator 无法使用自定义和 Lambda 层转换 keras 模型

问题描述

前段时间我编写了一个模型,它使用了一些自定义层定义,并使用 TF 1.12 和独立的 Keras 2.2.4 进行了训练。我已将 TF 的版本更新到 1.14 并切换到 tf.keras。使用自定义加载函数,我的模型构建、加载权重并生成预测。

现在,我正在尝试将我的 keras 模型转换为可用于推理的 TF Estimator,但我遇到了各种各样的问题。我相信它源于我的 Lambda 层中的 get_config() 方法。我目前这样定义它们:

class NamedLambda(Lambda):
    def __init__(self, name=None):
        Lambda.__init__(self, self.fn, name=name)

    @classmethod
    def invoke(cls, args, **kw):
        return cls(**kw)(args)

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, self.name)

class L2Normalize(NamedLambda):
    def fn(self, x):
        return K.l2_normalize(x, axis=-1)

当我检查时, get_config 方法工作得很好:

custom_objects['l2_normalize'].get_config()
{'arguments': DictWrapper({}),
 'dtype': 'float32',
 'function': 'fn',
 'function_type': 'function',
 'module': 'grademachine.utils',
 'name': 'l2_normalize',
 'output_shape': None,
 'output_shape_module': None,
 'output_shape_type': 'raw',
 'trainable': True}

下面是一些示例代码和让我难过的回溯。任何帮助将非常感激。

model = load_model(model_dir, 
                   options_fn='model123_options', 
                   weights_fn='model123_weights')
model
<tensorflow.python.keras.engine.training.Model at 0x7fe3d43d8e10>
est = tf.keras.estimator.model_to_estimator(keras_model=model)

我还尝试按如下方式添加我的自定义层,这会产生稍微不同的回溯,但最终会出现在同一个地方。下面的回溯来自定义了 custom_objects 的版本:

# custom_layer_names is a list of names of each of the custom layers in the trained model
custom_objects = {l.name: l for l in model.layers if l.name in custom_layer_names}
est = tf.keras.estimator.model_to_estimator(keras_model=model,  
                                            custom_objects=custom_objects)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpyujm6s99
INFO:tensorflow:Using the Keras model provided.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-13-512a382c338c> in <module>()
     13 est = tf.keras.estimator.model_to_estimator(keras_model=model, 
     14                                             model_dir='saved_estimator/',
---> 15                                             custom_objects=custom_objects)

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/estimator/__init__.py in model_to_estimator(keras_model, keras_model_path, custom_objects, model_dir, config)
     71       custom_objects=custom_objects,
     72       model_dir=model_dir,
---> 73       config=config)
     74 
     75 # LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in model_to_estimator(keras_model, keras_model_path, custom_objects, model_dir, config)
    448   if keras_model._is_graph_network:
    449     warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
--> 450                                              config)
    451   elif keras_model.built:
    452     logging.warning('You are creating an Estimator from a Keras model manually '

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in _save_first_checkpoint(keras_model, custom_objects, config)
    316       training_util.create_global_step()
    317       model = _clone_and_build_model(ModeKeys.TRAIN, keras_model,
--> 318                                      custom_objects)
    319       # save to checkpoint
    320       with session.Session(config=config.session_config) as sess:

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in _clone_and_build_model(mode, keras_model, custom_objects, features, labels)
    199       compile_clone=compile_clone,
    200       in_place_reset=(not keras_model._is_graph_network),
--> 201       optimizer_iterations=global_step)
    202 
    203   return clone

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in clone_and_build_model(model, input_tensors, target_tensors, custom_objects, compile_clone, in_place_reset, optimizer_iterations, optimizer_config)
    534     if custom_objects:
    535       with CustomObjectScope(custom_objects):
--> 536         clone = clone_model(model, input_tensors=input_tensors)
    537     else:
    538       clone = clone_model(model, input_tensors=input_tensors)

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in clone_model(model, input_tensors, clone_function)
    324   else:
    325     return _clone_functional_model(
--> 326         model, input_tensors=input_tensors, layer_fn=clone_function)
    327 
    328 

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in _clone_functional_model(model, input_tensors, layer_fn)
    152       # Get or create layer.
    153       if layer not in layer_map:
--> 154         new_layer = layer_fn(layer)
    155         layer_map[layer] = new_layer
    156         layer = new_layer

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in _clone_layer(layer)
     52 
     53 def _clone_layer(layer):
---> 54   return layer.__class__.from_config(layer.get_config())
     55 
     56 

~/repos/grademachine/grademachine/utils.py in from_config(cls, config, custom_objects)
    850     config = config.copy()
    851     function = cls._parse_function_from_config(
--> 852         config, custom_objects, 'function', 'module', 'function_type')
    853 
    854     output_shape = cls._parse_function_from_config(

~/repos/grademachine/grademachine/utils.py in _parse_function_from_config(cls, config, custom_objects, func_attr_name, module_attr_name, func_type_attr_name)
    898           config[func_attr_name],
    899           custom_objects=custom_objects,
--> 900           printable_module_name='function in Lambda layer')
    901     elif function_type == 'lambda':
    902       # Unsafe deserialization from bytecode

~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    207       obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    208     else:
--> 209       obj = module_objects.get(object_name)
    210       if obj is None:
    211         raise ValueError('Unknown ' + printable_module_name + ':' + object_name)

AttributeError: 'NoneType' object has no attribute 'get'

标签: tensorflowkeraspython-3.6tensorflow-estimatortf.keras

解决方案


推荐阅读