首页 > 解决方案 > 类型错误:('关键字参数不理解:','pool1')尝试使用自定义层加载模型时

问题描述

我正在尝试使用以下代码加载具有自定义层的模型:

model = load_model(BEST_MODEL_PATH, compile=False, custom_objects={'Localization': Localization})

但我得到了错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-73-668917c4ab3c> in <module>()
      3 BEST_MODEL_PATH = '/content/data/00002-test-train/model-01-0.903.hdf5'
      4 
----> 5 model = load_model(BEST_MODEL_PATH, compile=False, custom_objects={'Localization': Localization})
      6 
      7 predictions = model.predict(X_test)

14 frames
/usr/local/lib/python3.7/dist-packages/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
    199             (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
    200           return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
--> 201                                                   compile)
    202 
    203         filepath = path_to_string(filepath)

/usr/local/lib/python3.7/dist-packages/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
    179     model_config = json_utils.decode(model_config)
    180     model = model_config_lib.model_from_config(model_config,
--> 181                                                custom_objects=custom_objects)
    182 
    183     # set weights

/usr/local/lib/python3.7/dist-packages/keras/saving/model_config.py in model_from_config(config, custom_objects)
     50                     '`Sequential.from_config(config)`?')
     51   from keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 52   return deserialize(config, custom_objects=custom_objects)
     53 
     54 

/usr/local/lib/python3.7/dist-packages/keras/layers/serialization.py in deserialize(config, custom_objects)
    210       module_objects=LOCAL.ALL_OBJECTS,
    211       custom_objects=custom_objects,
--> 212       printable_module_name='layer')

/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    676             custom_objects=dict(
    677                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 678                 list(custom_objects.items())))
    679       else:
    680         with CustomObjectScope(custom_objects):

/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in from_config(cls, config, custom_objects)
    661     with generic_utils.SharedObjectLoadingScope():
    662       input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 663           config, custom_objects)
    664       model = cls(inputs=input_tensors, outputs=output_tensors,
    665                   name=config.get('name'))

/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1271   # First, we create all layers and enqueue nodes to be processed
   1272   for layer_data in config['layers']:
-> 1273     process_layer(layer_data)
   1274   # Then we process nodes in order of layer depth.
   1275   # Nodes that cannot yet be processed (if the inbound node

/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in process_layer(layer_data)
   1253       from keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1254 
-> 1255       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1256       created_layers[layer_name] = layer
   1257 

/usr/local/lib/python3.7/dist-packages/keras/layers/serialization.py in deserialize(config, custom_objects)
    210       module_objects=LOCAL.ALL_OBJECTS,
    211       custom_objects=custom_objects,
--> 212       printable_module_name='layer')

/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    679       else:
    680         with CustomObjectScope(custom_objects):
--> 681           deserialized_obj = cls.from_config(cls_config)
    682     else:
    683       # Then `cls` may be a function returning a class.

/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py in from_config(cls, config)
    746         A layer instance.
    747     """
--> 748     return cls(**config)
    749 
    750   def compute_output_shape(self, input_shape):

<ipython-input-55-f9b9fb9036d5> in __init__(self, filters_1, filters_2, fc_units, kernel_size, pool_size, **kwargs)
     14         self.fc1 = Dense(fc_units, activation='relu')
     15         self.fc2 = Dense(6, activation=None, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), kernel_initializer='zeros')
---> 16         super(Localization, self).__init__(**kwargs)
     17 
     18     def build(self, input_shape):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    528     self._self_setattr_tracking = False  # pylint: disable=protected-access
    529     try:
--> 530       result = method(self, *args, **kwargs)
    531     finally:
    532       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py in __init__(self, trainable, name, dtype, dynamic, **kwargs)
    321     }
    322     # Validate optional keyword arguments.
--> 323     generic_utils.validate_kwargs(kwargs, allowed_kwargs)
    324 
    325     # Mutable properties

/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py in validate_kwargs(kwargs, allowed_kwargs, error_message)
   1141   for kwarg in kwargs:
   1142     if kwarg not in allowed_kwargs:
-> 1143       raise TypeError(error_message, kwarg)
   1144 
   1145 

TypeError: ('Keyword argument not understood:', 'pool1')

我的自定义层:

class Localization(tf.keras.layers.Layer):
    def __init__(self, filters_1, filters_2, fc_units, kernel_size=(5,5), \
                 pool_size=(2,2), **kwargs):
        self.filters_1 = filters_1
        self.filters_2 = filters_2
        self.fc_units = fc_units
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.pool1 = MaxPooling2D(pool_size=pool_size)
        self.conv1 = Conv2D(filters=filters_1, kernel_size=kernel_size, padding='same', strides=1, activation='relu')
        self.pool2 = MaxPooling2D(pool_size=pool_size)
        self.conv2 = Conv2D(filters=filters_2, kernel_size=kernel_size, padding='same', strides=1, activation='relu')
        self.pool3 = MaxPooling2D(pool_size=pool_size)
        self.flatten = Flatten()
        self.fc1 = Dense(fc_units, activation='relu')
        self.fc2 = Dense(6, activation=None, bias_initializer=tf.keras.initializers.constant([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]), kernel_initializer='zeros')
        super(Localization, self).__init__(**kwargs)

    def build(self, input_shape):
        print("Building Localization Network with input shape:", input_shape)

    def compute_output_shape(self, input_shape):
        return [None, 6]

    def call(self, inputs):
        x = self.pool1(inputs)
        x = self.conv1(x)
        x = self.pool2(x)
        x = self.conv2(x)
        x = self.pool3(x)
        x = self.flatten(x)
        x = self.fc1(x)
        theta = self.fc2(x)
        theta = tf.keras.layers.Reshape((2, 3))(theta)
        return theta

    def get_config(self):
        config = super(Localization, self).get_config()
        config.update({
            'filters_1': self.filters_1,
            'filters_2': self.filters_2,
            'fc_units': self.fc_units,
            'kernel_size': self.kernel_size,
            'pool_size': self.pool_size,
            'pool1': self.pool1,
            'conv1': self.conv1,
            'pool2': self.pool2,
            'conv2': self.conv2,
            'pool3': self.pool3,
            'flatten': self.flatten,
            'fc1': self.fc1,
            'fc2': self.fc2,
        })
        return config

我正在使用升级到 tensorflow 2.6 的 google colab。在 tensorflow 2.5 中我没有这样的问题。

我被迫对图层进行许多更改以使其正常工作。在2.5中,我不需要assign filters_1filters_2而另一个args__init__(因为我没有在其他地方使用它们),传递**kwargs并编写get_config函数。

我什至尝试再次安装 tensorflow 2.5 和 keras,但在训练时会出错。我尝试了很多东西,搜索了文档并阅读了几乎所有类似的问题,但找不到任何东西。

标签: pythontensorflowkerasdeep-learning

解决方案


正如史努比博士他们的评论中所说,我删除了这些层get_config,现在它工作正常。


推荐阅读