首页 > 解决方案 > Keras:保存模型定义为一个类引发 NotImplementedError

问题描述

在阅读了在我的情况下不起作用的类似问题和答案后,我正在写这篇文章。你可能注意到我在第一层定义了输入形状。

我在 Keras 中创建了一个非常小的 CNN,如下:

import tensorflow as tf

class MyNet(tf.keras.Model):
    def __init__(self):
        super(MyNet, self).__init__()
         self.conv1 = tf.keras.layers.Conv2D(32, 5, strides = (2,2), data_format = 'channels_first', input_shape = (3,224,224))
         self.bn1 = tf.keras.layers.BatchNormalization(axis = 1)
         self.fc1 = tf.keras.layers.Dense(10)
         self.globalavg = tf.keras.layers.GlobalAveragePooling2D(data_format = 'channels_first')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = tf.keras.activations.relu(x)
        x = self.globalavg(x)

        return self.fc1(x)

然后我在其中输入了一些东西并成功打印了结果(此时权重可能是随机的,但没关系):

image = tf.ones(shape = (1, 3, 224, 224))  # Defined "channels first" when created the layers
mynet = MyNet()
outputs = mynet(image)
print(tf.keras.backend.eval(outputs))

我在这一步看到的结果是该fc1层的 10 个输出:

[[-1.1747773  -0.21640654 -0.16266493 -0.44879064 -0.642066    0.78132695  -0.03920581 -0.30874395 -0.04169023 -0.10409291]]

然后我尝试通过调用保存模型及其权重,mynet.save('mynet.hdf5')并得到以下错误:

NotImplementedError: Currently `save` requires model to be a graph network. Consider using `save_weights`, in order to save the weights of the model.

请注意,我是 Keras 的新手,我的大部分经验都是使用 PyTorch。

我究竟做错了什么?

更新:

按照@ikibir 的回答,我将网络重新定义为顺序网络:

myNetAsSeq = tf.keras.models.Sequential()
myNetAsSeq.add(tf.keras.layers.Conv2D(32, 5, strides = (2,2), data_format = 'channels_first', input_shape = (3,224,224)))
myNetAsSeq.add(tf.keras.layers.BatchNormalization(axis = 1))
myNetAsSeq.add(tf.keras.layers.Activation('relu'))
myNetAsSeq.add(tf.keras.layers.GlobalAveragePooling2D(data_format = 'channels_first'))
myNetAsSeq.add(tf.keras.layers.Dense(10))

这次调用myNetAsSeq.save('mynet.hdf5')成功了。

标签: keras

解决方案


我不确定我的答案,但我相信您不会创建模型,您只是单独创建每个层,当您运行“调用”函数时,您只需将变量传递给这些层。

在 keras 你应该使用

model = models.Sequential() 

用于创建模型,您应该使用

model.add()

添加图层

然后你可以保存这个模型


推荐阅读