首页 > 解决方案 > 如何确定 cross_validate 是否使用分层 K 折?

问题描述

我想确保cross_validate使用分层简历。在 的文档中cross_validate,有这样写

对于 int/None 输入,如果估计器是分类器并且 y 是二元或多类,则使用 StratifiedKFold。在所有其他情况下,使用 KFold。

我的估计器是一个分类器,我的因变量是二元的。所以理论上也通过设置cv=None我应该得到一个分层的简历。

我怎么能确定呢?如何检查是否cross_validate在这里:

rfc_score = cross_validate(rfc, desc_tfidf, labels, scoring=metrics)

真的是在使用分层简历吗?

标签: pythonscikit-learncross-validation

解决方案


从 的源代码cross-validate,我们可以看到该方法运行的第一件事是:

cv = check_cv(cv, y, classifier=is_classifier(estimator))

在 中check_cv,我们有:

cv = 5 if cv is None else cv
if isinstance(cv, numbers.Integral):
    if (classifier and (y is not None) and
            (type_of_target(y) in ('binary', 'multiclass'))):
        return StratifiedKFold(cv)
    else:
        return KFold(cv)

这正是文档所声称的。


推荐阅读