Cannot use `call()` on a model loaded with `tf.keras.models.load_model()` despite passing custom objects

Problem description

I am trying to build a custom model (SharedModel) out of custom blocks (ClassA, ClassB, ClassC). A visual overview can be found here. The reason for building the model this way is to be able to train it as-is and then strip out class_c during inference. This is a toy model meant to concisely represent the actual model.
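For concreteness, the inference-time idea could look something like the hypothetical wrapper below; InferenceModel is not part of the actual code, and it simply reuses the a_class and b_class blocks defined in SharedModel further down:

from tensorflow import keras

# Hypothetical inference-time wrapper: keeps the a_class -> b_class path
# and drops the c_class branch that is only needed during training.
class InferenceModel(keras.Model):
    def __init__(self, shared, **kwargs):
        super(InferenceModel, self).__init__(**kwargs)
        self.a_class = shared.a_class
        self.b_class = shared.b_class

    def call(self, inputs):
        return self.b_class(self.a_class(inputs))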

Code for ClassA

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D

# Conv2D -> BatchNorm -> ReLU -> MaxPool block
class ClassA(keras.Model):
    def __init__(self, **kwargs):
        super(ClassA, self).__init__(**kwargs)
        weight_decay = 0.0001
        self.L1 = Conv2D(64, kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))
        self.L2 = BatchNormalization()
        self.L3 = Activation('relu')
        self.L4 = MaxPooling2D(pool_size=(2,2), name="pool1")
    
    def call(self, inputs):
        x=self.L1(inputs)
        x=self.L2(x)
        x=self.L3(x)
        x=self.L4(x)
        return x

    def get_config(self):
        return {
            # "layers": {
            "L1": self.L1,
            "L2": self.L2,
            "L3": self.L3,
            "L4": self.L4,
            # }
        }

The implementations of ClassB and ClassC are identical to ClassA.

Code for SharedModel

from tensorflow.keras import Input, Model

# @tf.keras.utils.register_keras_serializable()
class SharedModel(keras.Model):
    def __init__(self, **kwargs):
        super(SharedModel, self).__init__(**kwargs)
        self.a_class=ClassA()
        self.b_class=ClassB()
        self.c_class=ClassC()
    
    def call(self, inputs, **kwargs):
        out1=self.a_class(inputs)
        out2=self.b_class(out1)
        out3=self.c_class(out1)

        return out2, out3
    
    def get_config(self):
        return {
            # "layers": {
            "a_class": self.a_class,
            "b_class": self.b_class,
            "c_class": self.c_class,
            # "build_graph": self.build_graph
            # }
        }
        # "base_config": super(SharedModel, self).get_config()}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    # @tf.keras.utils.register_keras_serializable()
    def build_graph(self, dim):
        x = Input(shape=(dim))
        return Model(inputs=x, outputs = self.call(inputs=x), name="Shared Model")
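The build/save step is not shown in the question; a minimal sketch of what it presumably looks like, assuming the (None, 216, 64, 1) input shape that appears in the traceback below and the "GraphModel" SavedModel directory used in the load call:

import numpy as np
import tensorflow as tf

shared = SharedModel()
# Build the weights with one forward pass on dummy data
# (the 216 x 64 x 1 input shape is taken from the traceback below).
_ = shared(np.zeros((1, 216, 64, 1), dtype="float32"))
# Save in the TensorFlow SavedModel format, to the directory loaded below.
shared.save("GraphModel")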

I am able to build and save the model without any problems. However, when I load it with

recreate_model=tf.keras.models.load_model("GraphModel", custom_objects={"SharedModel": SharedModel})

and then try to use it, either as recreate_model(input) or even with an explicit recreate_model.call(input), I get the following error:

/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/load.py in _unable_to_call_layer_due_to_serialization_issue(layer, *unused_args, **unused_kwargs)
    902 
    903   raise ValueError(
--> 904       f'Cannot call custom layer {layer.name} of type {type(layer)}, because '
    905       'the call function was not serialized to the SavedModel.'
    906       'Please try one of the following methods to fix this issue:'

ValueError: Exception encountered when calling layer "shared_model" (type SharedModel).

Cannot call custom layer shared_model of type <class 'keras.saving.saved_model.load.SharedModel'>, because the call function was not serialized to the SavedModel.Please try one of the following methods to fix this issue:

(1) Implement `get_config` and `from_config` in the layer/model class, and pass the object to the `custom_objects` argument when loading the model. For more details, see: https://www.tensorflow.org/guide/keras/save_and_serialize

(2) Ensure that the subclassed model or layer overwrites `call` and not `__call__`. The input shape and dtype will be automatically recorded when the object is called, and used when saving. To manually specify the input shape/dtype, decorate the call function with `@tf.function(input_signature=...)`.

Call arguments received:
  • unused_args=('tf.Tensor(shape=(None, 216, 64, 1), dtype=float32)',)
  • unused_kwargs={'training': 'None'}
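For reference, fix (1) above points at the pattern described in the linked guide: get_config returns the plain constructor arguments needed to rebuild the object (rather than layer instances), and from_config passes them back to __init__. Below is a minimal sketch of that pattern applied to a block like ClassA; exposing weight_decay as a constructor argument here is purely illustrative:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D

# Registering the class is an alternative to passing it via custom_objects.
@tf.keras.utils.register_keras_serializable()
class ClassA(keras.Model):
    def __init__(self, weight_decay=0.0001, **kwargs):
        super(ClassA, self).__init__(**kwargs)
        self.weight_decay = weight_decay
        self.L1 = Conv2D(64, kernel_size=(3, 3), padding="same",
                         kernel_regularizer=regularizers.l2(weight_decay))
        self.L2 = BatchNormalization()
        self.L3 = Activation('relu')
        self.L4 = MaxPooling2D(pool_size=(2, 2), name="pool1")

    def call(self, inputs):
        return self.L4(self.L3(self.L2(self.L1(inputs))))

    def get_config(self):
        # Plain, serializable constructor arguments only -- no layer objects.
        return {"weight_decay": self.weight_decay}

    @classmethod
    def from_config(cls, config):
        return cls(**config)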

As shown, I have implemented get_config for all the subclasses and from_config for SharedModel, and I am passing SharedModel in via custom_objects. I also tried passing the blocks to custom_objects as well, but then I get this error at load time:

1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1323 
   1324   # First, we create all layers and enqueue nodes to be processed
-> 1325   for layer_data in config['layers']:
   1326     process_layer(layer_data)
   1327   # Then we process nodes in order of layer depth.

KeyError: 'layers'
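The load call in that attempt presumably looked something like the following (the exact dictionary keys used are an assumption):

recreate_model = tf.keras.models.load_model(
    "GraphModel",
    custom_objects={
        "SharedModel": SharedModel,
        "ClassA": ClassA,
        "ClassB": ClassB,
        "ClassC": ClassC,
    },
)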

I did also consider implementing from_config for the blocks, but it appears to be optional.

According to the documentation, when custom objects are passed, recreate_model should have the same class as model (<__main__.SharedModel>), but instead its class is <keras.saving.saved_model.load.SharedModel>.

From these errors, I am guessing I am missing something in get_config and from_config. Can someone tell me what I am doing wrong and point me toward a solution? Also, how can I access build_graph of SharedModel after it is loaded?

Tags: python, tensorflow, keras, tensorflow2.0

Solution

