How to get summary graph of custom (subclass) Keras layer?

Problem description

How do you print a summary() of the layers of a custom layer?

model.summary() prints a nicely formatted summary table of the entire model, but the subclassed layer called 'magic_layer' here, which contains many layers, is collapsed into a single row:

Model: "transformer"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
positioning (PositionalEncod (None, 504, 6)            0         
_________________________________________________________________
magic_layer (CustomLayer)    (None, 6, 504)            3088040   
_________________________________________________________________
g_pooling (GlobalAveragePool (None, 504)               0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 504)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 32)                16160     
_________________________________________________________________
dropout_3 (Dropout)          (None, 32)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 5)                 165       
=================================================================
Total params: 3,104,365
Trainable params: 3,104,365
Non-trainable params: 0
_________________________________________________________________

If you have a custom-defined TensorFlow/Keras layer (read more about that here: Making new layers and models via subclassing - François Chollet), then summary() won't break out the layers inside that sublayer. 'magic_layer' in this example is the subclassed layer I'm interested in.
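A minimal reproduction of the behavior, assuming TensorFlow 2.x. CustomLayer, its two Dense sublayers and the input shape are hypothetical stand-ins for 'magic_layer', not the asker's actual code:

```python
import tensorflow as tf


class CustomLayer(tf.keras.layers.Layer):
    """Hypothetical stand-in for 'magic_layer': a plain Layer with sublayers."""

    def __init__(self, units=8, **kwargs):
        super().__init__(**kwargs)
        self.dense_a = tf.keras.layers.Dense(units, activation="relu")
        self.dense_b = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense_b(self.dense_a(inputs))


inputs = tf.keras.Input(shape=(4,))
outputs = CustomLayer(name="magic_layer")(inputs)
model = tf.keras.Model(inputs, outputs)

# magic_layer shows up as one row; its two Dense sublayers are not listed.
model.summary()

# summary() is defined on Model, not on Layer, so the custom layer lacks it.
print(hasattr(model.get_layer("magic_layer"), "summary"))
```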

How could you get this same print-out of sublayers for the layer called 'magic_layer' in this example?

Unfortunately, model.layers[1].summary() does not work. Perhaps I need to define a summary method in the custom layer class, but I was hoping there was a way to inherit this functionality from the Model class.

Tags: tensorflow, keras

Solution


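Since Model is itself a subclass of Layer, simply subclass your custom layer from tf.keras.Model instead of tf.keras.layers.Layer. You can then print a summary of the "layer" with summary(), once it has been built (for example, by calling it on an input).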
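A minimal sketch of this approach, assuming TensorFlow 2.x. CustomLayer, its two Dense sublayers and the input shape are hypothetical; the only change from a plain custom layer is the base class:

```python
import tensorflow as tf


class CustomLayer(tf.keras.Model):  # subclass Model instead of layers.Layer
    """Hypothetical stand-in for 'magic_layer'; the body is unchanged."""

    def __init__(self, units=8, **kwargs):
        super().__init__(**kwargs)
        self.dense_a = tf.keras.layers.Dense(units, activation="relu")
        self.dense_b = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense_b(self.dense_a(inputs))


magic = CustomLayer(name="magic_layer")
magic(tf.zeros((1, 4)))  # one forward pass builds the weights of both sublayers
magic.summary()          # lists the two Dense sublayers as separate rows

# It can still be used like any other layer inside a larger model.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(magic(inputs))
model = tf.keras.Model(inputs, outputs)
```

The outer model.summary() will still show magic_layer as a single row. Recent TensorFlow/Keras releases accept an expand_nested argument to summary() for nested models, which may be worth checking in your version's documentation.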

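model.summary() is not recursive: it does not print summaries of nested models. If you need that, you have to write it yourself, or simply create your own summary function based on the original source: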

https://github.com/keras-team/keras/blob/07a22914c8114a74238fd86741749cab5af299ce/keras/utils/layer_utils.py#L116
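A minimal sketch of a hand-written recursive summary, assuming TensorFlow 2.x; it is not a port of the linked source. The helpers direct_sublayers and summarize_recursive and the CustomLayer demo class are hypothetical names. For plain Layer subclasses the sketch assumes sublayers are stored as public attributes (or lists/tuples of layers), as in the Keras subclassing guide, and every layer must already be built because count_params() raises an error otherwise:

```python
import tensorflow as tf


def direct_sublayers(layer):
    """Best-effort list of the layers directly owned by `layer`.

    Models expose them through the public `.layers` property. For a plain
    Layer subclass this assumes sublayers were assigned to public attributes
    (or lists/tuples of layers), e.g. `self.dense_a = Dense(...)` in __init__.
    """
    if isinstance(layer, tf.keras.Model):
        return list(layer.layers)
    found, seen = [], set()
    for attr, value in vars(layer).items():
        if attr.startswith("_"):
            continue  # skip Keras' private bookkeeping attributes
        candidates = value if isinstance(value, (list, tuple)) else [value]
        for item in candidates:
            if isinstance(item, tf.keras.layers.Layer) and id(item) not in seen:
                seen.add(id(item))
                found.append(item)
    return found


def summarize_recursive(layer, depth=0):
    """Return one indented line per layer, descending into nested layers."""
    lines = [f"{'  ' * depth}{layer.name} ({type(layer).__name__}): "
             f"{layer.count_params()} params"]
    for child in direct_sublayers(layer):
        lines.extend(summarize_recursive(child, depth + 1))
    return lines


class CustomLayer(tf.keras.layers.Layer):
    """Hypothetical stand-in for 'magic_layer': a plain Layer with sublayers."""

    def __init__(self, units=8, **kwargs):
        super().__init__(**kwargs)
        self.dense_a = tf.keras.layers.Dense(units, activation="relu")
        self.dense_b = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return self.dense_b(self.dense_a(inputs))


inputs = tf.keras.Input(shape=(4,))
hidden = CustomLayer(name="magic_layer")(inputs)
outputs = tf.keras.layers.Dense(2)(hidden)
model = tf.keras.Model(inputs, outputs, name="demo")

print("\n".join(summarize_recursive(model)))
```

Scanning attributes instead of calling a private Keras method keeps the sketch on public APIs, at the cost of missing sublayers stored in other ways (for example, in dicts).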

