python - 尽管传递了自定义对象,但无法在使用 `tf.keras.models.load_model()` 加载的模型上使用 `call()`
问题描述
我正在尝试SharedModel
使用自定义块 ( ClassA
, ClassB
, ) 构建自定义模型 ( ClassC
)。可以在此处找到视觉概览。建立这个模型的原因是能够按原样训练它,然后class_c
在推理过程中去掉它。这是一个玩具模型,以简洁地代表实际模型。
代码ClassA
:
class ClassA(keras.Model):
def __init__(self, **kwargs):
super(ClassA, self).__init__(**kwargs)
weight_decay = 0.0001
self.L1 = Conv2D(64, kernel_size=(3,3),padding = "same",kernel_regularizer=regularizers.l2(weight_decay))
self.L2 = BatchNormalization()
self.L3 = Activation('relu')
self.L4 = MaxPooling2D(pool_size=(2,2), name="pool1")
def call(self, inputs):
x=self.L1(inputs)
x=self.L2(x)
x=self.L3(x)
x=self.L4(x)
return x
def get_config(self):
return {
# "layers": {
"L1": self.L1,
"L2": self.L2,
"L3": self.L3,
"L4": self.L4,
# }
}
和的实现ClassB
与ClassC
相同ClassA
。
代码SharedModel
:
# @tf.keras.utils.register_keras_serializable()
class SharedModel(keras.Model):
def __init__(self, **kwargs):
super(SharedModel, self).__init__(**kwargs)
self.a_class=ClassA()
self.b_class=ClassB()
self.c_class=ClassC()
def call(self, inputs, **kwargs):
out1=self.a_class(inputs)
out2=self.b_class(out1)
out3=self.c_class(out1)
return out2, out3
def get_config(self):
return {
# "layers": {
"a_class": self.a_class,
"b_class": self.b_class,
"c_class": self.c_class,
# "build_graph": self.build_graph
# }
}
# "base_config": super(SharedModel, self).get_config()}
@classmethod
def from_config(cls, config):
return cls(**config)
# @tf.keras.utils.register_keras_serializable()
def build_graph(self, dim):
x = Input(shape=(dim))
return Model(inputs=x, outputs = self.call(inputs=x), name="Shared Model")
我能够毫无问题地构建和保存模型。但是,当我尝试使用
recreate_model=tf.keras.models.load_model("GraphModel", custom_objects={"SharedModel": SharedModel})
并尝试使用其中一个recreate_model(input)
甚至是显式调用recreate_model.call(input)
,我收到以下错误:
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/load.py in _unable_to_call_layer_due_to_serialization_issue(layer, *unused_args, **unused_kwargs)
902
903 raise ValueError(
--> 904 f'Cannot call custom layer {layer.name} of type {type(layer)}, because '
905 'the call function was not serialized to the SavedModel.'
906 'Please try one of the following methods to fix this issue:'
ValueError: Exception encountered when calling layer "shared_model" (type SharedModel).
Cannot call custom layer shared_model of type <class 'keras.saving.saved_model.load.SharedModel'>, because the call function was not serialized to the SavedModel.Please try one of the following methods to fix this issue:
(1) Implement `get_config` and `from_config` in the layer/model class, and pass the object to the `custom_objects` argument when loading the model. For more details, see: https://www.tensorflow.org/guide/keras/save_and_serialize
(2) Ensure that the subclassed model or layer overwrites `call` and not `__call__`. The input shape and dtype will be automatically recorded when the object is called, and used when saving. To manually specify the input shape/dtype, decorate the call function with `@tf.function(input_signature=...)`.
Call arguments received:
• unused_args=('tf.Tensor(shape=(None, 216, 64, 1), dtype=float32)',)
• unused_kwargs={'training': 'None'}
如图所示,我已经get_config
为所有子类和from_config
for实现了该方法,SharedModel
并且正在传入SharedModel
. custom_objects
我也尝试将块传递给custom_ojects
,但是在加载时出现此错误:
1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.7/dist-packages/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
1323
1324 # First, we create all layers and enqueue nodes to be processed
-> 1325 for layer_data in config['layers']:
1326 process_layer(layer_data)
1327 # Then we process nodes in order of layer depth.
KeyError: 'layers'
我确实也想过from_config
为这些块实现,但它似乎是可选的。
根据文档,在传递自定义对象时recreate_model
应该具有与model
( ) 类似的对象 id 类,但它有。<__main__.SharedModel>
<keras.saving.saved_model.load.SharedModel>
从这些错误中,我猜我在 and 中遗漏了一些get_config
东西from_config
。有人可以告诉我我做错了什么并指出解决方案吗?此外,如何build_graph
在SharedModel
加载后访问?
解决方案
推荐阅读
- php - 如何在 Laravel 中自定义对象集合
- php - 24小时后在php中从数据库中删除行
- javascript - 如何清除从 JQuery.HTML() 检索到的 javascript 对象
- r - 寻找类似 SIMPROF 的聚类分析,但允许每个类别进行许多观察
- python - 如果它是其他行的子集,则熊猫会丢弃行
- ssis - SSIS如何选择并行执行的数据流任务?
- javascript - jQuery单击事件未按预期工作
- python - 如何让 sys.stdout 将结果打印到 tkinter GUI 文本框中
- oracle - Oracle COUNT(DISTINCT expr) 导致 ORA-00979 错误
- ios - 无法从 GET 请求将数据附加到数组