keras - Keras:保存模型定义为一个类引发 NotImplementedError
问题描述
在阅读了在我的情况下不起作用的类似问题和答案后,我正在写这篇文章。你可能注意到我在第一层定义了输入形状。
我在 Keras 中创建了一个非常小的 CNN,如下:
import tensorflow as tf
class MyNet(tf.keras.Model):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = tf.keras.layers.Conv2D(32, 5, strides = (2,2), data_format = 'channels_first', input_shape = (3,224,224))
self.bn1 = tf.keras.layers.BatchNormalization(axis = 1)
self.fc1 = tf.keras.layers.Dense(10)
self.globalavg = tf.keras.layers.GlobalAveragePooling2D(data_format = 'channels_first')
def call(self, inputs):
x = self.conv1(inputs)
x = self.bn1(x)
x = tf.keras.activations.relu(x)
x = self.globalavg(x)
return self.fc1(x)
然后我在其中输入了一些东西并成功打印了结果(此时权重可能是随机的,但没关系):
image = tf.ones(shape = (1, 3, 224, 224)) # Defined "channels first" when created the layers
mynet = MyNet()
outputs = mynet(image)
print(tf.keras.backend.eval(outputs))
我在这一步看到的结果是该fc1
层的 10 个输出:
[[-1.1747773 -0.21640654 -0.16266493 -0.44879064 -0.642066 0.78132695 -0.03920581 -0.30874395 -0.04169023 -0.10409291]]
然后我尝试通过调用保存模型及其权重,mynet.save('mynet.hdf5')
并得到以下错误:
NotImplementedError: Currently `save` requires model to be a graph network. Consider using `save_weights`, in order to save the weights of the model.
请注意,我是 Keras 的新手,我的大部分经验都是使用 PyTorch。
我究竟做错了什么?
更新:
按照@ikibir 的回答,我将网络重新定义为顺序网络:
myNetAsSeq = tf.keras.models.Sequential()
myNetAsSeq.add(tf.keras.layers.Conv2D(32, 5, strides = (2,2), data_format = 'channels_first', input_shape = (3,224,224)))
myNetAsSeq.add(tf.keras.layers.BatchNormalization(axis = 1))
myNetAsSeq.add(tf.keras.layers.Activation('relu'))
myNetAsSeq.add(tf.keras.layers.GlobalAveragePooling2D(data_format = 'channels_first'))
myNetAsSeq.add(tf.keras.layers.Dense(10))
这次调用myNetAsSeq.save('mynet.hdf5')
成功了。
解决方案
我不确定我的答案,但我相信您不会创建模型,您只是单独创建每个层,当您运行“调用”函数时,您只需将变量传递给这些层。
在 keras 你应该使用
model = models.Sequential()
用于创建模型,您应该使用
model.add()
添加图层
然后你可以保存这个模型
推荐阅读
- angular - 如何检查Angular反应形式中的所有复选框
- ms-access - MS-Access 中计数值的百分比
- javascript - 如何使用 JavaScript 将 csv 文件从 NetSuite 中的文件柜传递到 ftp 服务器?
- cordova - 升级到 Android Pie 后不会发生 onregistered 事件
- c - 如何设计测试用例来验证数据包解码器的节流能力?
- javascript - 如何在 React 应用程序中执行 Python 代码?
- awk - 如何使用 awk 连接特定行?
- google-cloud-platform - 如何删除大查询分区表?
- python - 如何使用 Python 登录网站
- ios - How to find the iOS SDK version that an app is using?