tensorflow - 保存 TensorFlow 神经网络 KFold 交叉验证模型
问题描述
我正在使用 TensorFlow 2.4.1 研究具有 KFold 交叉验证的示例神经网络。和sklearn。不幸的是,我无法保存模型。
def my_model(self,):
inputs = keras.Input(shape=(48, 48, 3))
x = layers.Conv2D(filters=4, kernel_size=self.k_size, padding='same', activation="relu")(inputs)
x = layers.BatchNormalization()(x)
x = layers.MaxPool2D()(x)
x = layers.Flatten()(x)
output = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=output)
model.compile(optimizer='adam',
loss=[keras.losses.SparseCategoricalCrossentropy(from_logits=True)],
metrics=['accuracy'])
return model
def train_model(self):
try:
os.mkdir('model/saved_models')
except OSError:
pass
try:
os.mkdir('model/saved_graphs')
except OSError:
pass
kf = KFold(n_splits=3)
for train_index, test_index in kf.split(self.x_train):
x_train, x_test = self.x_train[train_index], self.x_train[test_index]
y_train, y_test = self.y_train[train_index], self.y_train[test_index]
model = self.my_model()
print(model.summary())
trained_model = model.fit(x_train, y_train, epochs=self.epochs, steps_per_epoch=10, verbose=2)
trained_model = trained_model.history
print('Model evaluation', model.evaluate(x_test, y_test, verbose = 2))
trained_model.save(f'model/saved_models/dummy_model_{date}')
return trained_model
我收到以下错误:
trained_model.save(f'model/saved_models/dummy_model_{date}')
AttributeError: 'dict' object has no attribute 'save'
我无法想出一种将训练模型从 for 循环中取出的方法。这可能是我能想到的这个问题的可能原因。
有人可以建议我们如何解决这个问题吗?或者有没有其他方法可以使用 KFold 构建 ANN?
谢谢。
解决方案
是的,您的代码有一些错字:
trained_model = trained_model.history # This is your train stats, so your train stats is a dictionary
model.save(f'model/saved_models/dummy_model_{date}') # This is what your saving the actual model
推荐阅读
- ios - ScrollView 无法与 GeometryReader 一起正常工作
- url - 如何从 url 下载 m3u 文件以及帐户 iptv 的所有详细信息?
- sql-server - 在 sql server 2019 中设置 cdc 并注册 debezium sql server 连接器后运行 kafka consumer 时无法生成任何日志
- java - java.lang.OutOfMemoryError: 问题
- cygwin - 如何使用 cygwin 执行 wget 脚本?
- python - 为什么 rect 参数在 pygame 填充中不起作用?
- php - 我怎么知道我的客户端何时从不同的设备登录到我的应用程序
- r - R:data.table-将变量添加到数据表列表中,其中包含列表中每个表的名称
- html - Flexbox 和图像
- docker - Docker 最佳实践:使用操作系统或应用程序作为基础镜像?