首页 > 解决方案 > 在 K-Fold Cross 验证中,在 keras 中在哪里创建模型对象?

问题描述

在 K-fold 循环内部还是外部创建 Keras 模型对象?请解释为什么您的答案是正确的。

def model_def(): 
     model = Sequential()
     model.add(.... so on....)
     model.compile(....so on ....)
     return model

案例 1:- 在 K-fold 循环内,因此它正在为每个循环重新创建

for train_index, test_index in kf.split(X,Y):
     model = model_def()
     model.fit(X[train_index],Y[test_index] ..... so on .....

或者,案例 2:- 在循环之外,因此所有折叠循环的单个模型对象

model = model_def()
for train_index, test_index in kf.split(X,Y):
     model.fit(X[train_index],Y[test_index] ..... so on .....

标签: kerasscikit-learndeep-learningcross-validationk-fold

解决方案


里面。

对于每一次折叠,您都希望拥有一个全新的模型。这意味着您的模型不能通过来自另一个折叠的数据来学习任何权重(如果您在内部进行,则会发生这种情况,因为在每个折叠中您都在同一个实例上操作)。k-fold 学习的目的是检查你的模型在一小部分数据集上的表现,所以它不应该有任何关于其他折叠数据的信息。


推荐阅读