首页 > 解决方案 > 保存 TensorFlow 神经网络 KFold 交叉验证模型

问题描述

我正在使用 TensorFlow 2.4.1 研究具有 KFold 交叉验证的示例神经网络。和sklearn。不幸的是,我无法保存模型。

def my_model(self,):
            inputs = keras.Input(shape=(48, 48, 3))
            x = layers.Conv2D(filters=4, kernel_size=self.k_size, padding='same', activation="relu")(inputs)
            x = layers.BatchNormalization()(x)
            x = layers.MaxPool2D()(x)
            x = layers.Flatten()(x)
            output = layers.Dense(10, activation='softmax')(x)
    
            model = keras.Model(inputs=inputs, outputs=output)
    
            model.compile(optimizer='adam',
                          loss=[keras.losses.SparseCategoricalCrossentropy(from_logits=True)],
                          metrics=['accuracy'])
            return model
    
def train_model(self):
    
            try:
                os.mkdir('model/saved_models')
            except OSError:
                pass
    
            try:
                os.mkdir('model/saved_graphs')
            except OSError:
                pass
    
            kf = KFold(n_splits=3)
            for train_index, test_index in kf.split(self.x_train):
                x_train, x_test = self.x_train[train_index], self.x_train[test_index]
                y_train, y_test = self.y_train[train_index], self.y_train[test_index]
                model = self.my_model()
                print(model.summary())
                trained_model = model.fit(x_train, y_train, epochs=self.epochs, steps_per_epoch=10, verbose=2)
                trained_model = trained_model.history
                print('Model evaluation', model.evaluate(x_test, y_test, verbose = 2))
                trained_model.save(f'model/saved_models/dummy_model_{date}')
                return trained_model

我收到以下错误:

    trained_model.save(f'model/saved_models/dummy_model_{date}')
AttributeError: 'dict' object has no attribute 'save'

我无法想出一种将训练模型从 for 循环中取出的方法。这可能是我能想到的这个问题的可能原因。

有人可以建议我们如何解决这个问题吗?或者有没有其他方法可以使用 KFold 构建 ANN?

谢谢。

标签: tensorflowmachine-learningscikit-learndeep-learning

解决方案


是的,您的代码有一些错字:

trained_model = trained_model.history # This is your train stats, so your train stats is a dictionary
model.save(f'model/saved_models/dummy_model_{date}') # This is what your saving the actual model

推荐阅读