tensorflow - tf.keras.estimator.model_to_estimator 无法使用自定义和 Lambda 层转换 keras 模型
问题描述
前段时间我编写了一个模型,它使用了一些自定义层定义,并使用 TF 1.12 和独立的 Keras 2.2.4 进行了训练。我已将 TF 的版本更新到 1.14 并切换到 tf.keras。使用自定义加载函数,我的模型构建、加载权重并生成预测。
现在,我正在尝试将我的 keras 模型转换为可用于推理的 TF Estimator,但我遇到了各种各样的问题。我相信它源于我的 Lambda 层中的 get_config() 方法。我目前这样定义它们:
class NamedLambda(Lambda):
def __init__(self, name=None):
Lambda.__init__(self, self.fn, name=name)
@classmethod
def invoke(cls, args, **kw):
return cls(**kw)(args)
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__, self.name)
class L2Normalize(NamedLambda):
def fn(self, x):
return K.l2_normalize(x, axis=-1)
当我检查时, get_config 方法工作得很好:
custom_objects['l2_normalize'].get_config()
{'arguments': DictWrapper({}),
'dtype': 'float32',
'function': 'fn',
'function_type': 'function',
'module': 'grademachine.utils',
'name': 'l2_normalize',
'output_shape': None,
'output_shape_module': None,
'output_shape_type': 'raw',
'trainable': True}
下面是一些示例代码和让我难过的回溯。任何帮助将非常感激。
- Python版本:3.6.2
- TensorFlow 版本:1.14.0
- Keras 版本:2.2.4-tf
model = load_model(model_dir,
options_fn='model123_options',
weights_fn='model123_weights')
model
<tensorflow.python.keras.engine.training.Model at 0x7fe3d43d8e10>
est = tf.keras.estimator.model_to_estimator(keras_model=model)
我还尝试按如下方式添加我的自定义层,这会产生稍微不同的回溯,但最终会出现在同一个地方。下面的回溯来自定义了 custom_objects 的版本:
# custom_layer_names is a list of names of each of the custom layers in the trained model
custom_objects = {l.name: l for l in model.layers if l.name in custom_layer_names}
est = tf.keras.estimator.model_to_estimator(keras_model=model,
custom_objects=custom_objects)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpyujm6s99
INFO:tensorflow:Using the Keras model provided.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-13-512a382c338c> in <module>()
13 est = tf.keras.estimator.model_to_estimator(keras_model=model,
14 model_dir='saved_estimator/',
---> 15 custom_objects=custom_objects)
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/estimator/__init__.py in model_to_estimator(keras_model, keras_model_path, custom_objects, model_dir, config)
71 custom_objects=custom_objects,
72 model_dir=model_dir,
---> 73 config=config)
74
75 # LINT.ThenChange(//tensorflow_estimator/python/estimator/keras.py)
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in model_to_estimator(keras_model, keras_model_path, custom_objects, model_dir, config)
448 if keras_model._is_graph_network:
449 warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
--> 450 config)
451 elif keras_model.built:
452 logging.warning('You are creating an Estimator from a Keras model manually '
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in _save_first_checkpoint(keras_model, custom_objects, config)
316 training_util.create_global_step()
317 model = _clone_and_build_model(ModeKeys.TRAIN, keras_model,
--> 318 custom_objects)
319 # save to checkpoint
320 with session.Session(config=config.session_config) as sess:
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py in _clone_and_build_model(mode, keras_model, custom_objects, features, labels)
199 compile_clone=compile_clone,
200 in_place_reset=(not keras_model._is_graph_network),
--> 201 optimizer_iterations=global_step)
202
203 return clone
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in clone_and_build_model(model, input_tensors, target_tensors, custom_objects, compile_clone, in_place_reset, optimizer_iterations, optimizer_config)
534 if custom_objects:
535 with CustomObjectScope(custom_objects):
--> 536 clone = clone_model(model, input_tensors=input_tensors)
537 else:
538 clone = clone_model(model, input_tensors=input_tensors)
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in clone_model(model, input_tensors, clone_function)
324 else:
325 return _clone_functional_model(
--> 326 model, input_tensors=input_tensors, layer_fn=clone_function)
327
328
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in _clone_functional_model(model, input_tensors, layer_fn)
152 # Get or create layer.
153 if layer not in layer_map:
--> 154 new_layer = layer_fn(layer)
155 layer_map[layer] = new_layer
156 layer = new_layer
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/models.py in _clone_layer(layer)
52
53 def _clone_layer(layer):
---> 54 return layer.__class__.from_config(layer.get_config())
55
56
~/repos/grademachine/grademachine/utils.py in from_config(cls, config, custom_objects)
850 config = config.copy()
851 function = cls._parse_function_from_config(
--> 852 config, custom_objects, 'function', 'module', 'function_type')
853
854 output_shape = cls._parse_function_from_config(
~/repos/grademachine/grademachine/utils.py in _parse_function_from_config(cls, config, custom_objects, func_attr_name, module_attr_name, func_type_attr_name)
898 config[func_attr_name],
899 custom_objects=custom_objects,
--> 900 printable_module_name='function in Lambda layer')
901 elif function_type == 'lambda':
902 # Unsafe deserialization from bytecode
~/anaconda2/envs/berttf114/lib/python3.6/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
207 obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
208 else:
--> 209 obj = module_objects.get(object_name)
210 if obj is None:
211 raise ValueError('Unknown ' + printable_module_name + ':' + object_name)
AttributeError: 'NoneType' object has no attribute 'get'
解决方案
推荐阅读
- node.js - 如何在 NestJS 中正确使用 keycloak
- java - Java8 对 Map 中的 Map 值进行排序
- java - 无法从 FireBase 下载图像
- java - 使用 ResultSet 的方法 first() 时出现“Feature not supported”错误
- java - 即使我的用户已经订阅(沙盒),iap.subscribe(Skus[]) 机制也会返回 false
- python - 合并/组合单列中的重复项而不会丢失其他列中的数据
- c# - 基于登录用户我想隐藏或显示一些动作视图
- angular - 如何从Angular中的一组输入字段中禁用输入字段
- c# - What does .ToUniversalTime() really do on a DateTime instance (C#)?
- swiftui - 拖放列表以在 SwiftUI 上重新排序