首页 > 解决方案 > 绘图模型不显示模型层,仅显示模型名称

问题描述

我正在尝试使用 TensorFlow2 构建一些模型,因此我创建了一个模型类,如下所示:

import tensorflow as tf

class Dummy(tf.keras.Model):
    def __init__(self, name="dummy"):
        super(Dummy, self).__init__()
        self._name = name

        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        return self.dense2(x)

model = Dummy()
model.build(input_shape=(None,5))

现在我想绘制模型,同时使用summary()返回我期望的结果,plot_model(model, show_shapes=True, expand_nested=True)只返回一个带有模型名称的块。

如何返回我的模型图?

标签: pythonpython-3.xtensorflowkerastensorflow2.x

解决方案


Francois Chollet 说:

您可以在功能或顺序模型中执行所有这些操作(打印输入/输出形状),因为这些模型是层的静态图。

相反,子类模型是一段 Python 代码(调用方法)。这里没有图层图。我们无法知道层是如何相互连接的(因为这是在调用主体中定义的,而不是作为显式数据结构),因此我们无法推断输入/输出形状。

有两种解决方案:

  1. 您可以按顺序/使用功能 api 构建模型。
  2. 您将您的 ' ' 函数包装call成一个函数模型,如下所示:

class Subclass(Model)

def __init__(self):
    ...
def call(self, x):
    ...

def model(self):
    x = Input(shape=(24, 24, 3))
    return Model(inputs=[x], outputs=self.call(x))


if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()

答案取自这里:model.summary() can't print output shape while using subclass model

此外,这是一篇很好的文章:https ://medium.com/tensorflow/what-are-symbolic-and-imperative-apis-in-tensorflow-2-0-dfccecb01021


推荐阅读