首页 > 解决方案 > 结合GridSearch自定义交叉验证功能

问题描述

我目前正在自学 Python 和机器学习,并且正在从事一个处理分类的项目。我有可执行代码,我想自己重写并学习它。现在我已经到了靠自己无法前进的地步。我使用具有不同分类器的集成,例如 AdaBoost、CatBoost、XGBoost 等。

第一个函数是自定义的cross val函数,可以理解。第二个函数是GridSearch的扩展函数,我不是很了解,现在想改写成一个“普通的GridSearch函数”。我将不胜感激有关第二个功能的任何提示和帮助

可以在此处找到名为 ParamSearch 的原始“自定义”GridSearch: https ://effectiveml.com/files/paramsearch.py

def crossvaltest_cat(params, X, y, n_splits=5):
    skf = StratifiedKFold(n_splits=5)
    accuracy, score, f1 = [], [], []
    for train_index, test_index in skf.split(X, y):
        X_train, X_test = X.iloc[train_index, :], X.iloc[test_index, :]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]
        
        clf = CatBoostClassifier(**params)
        clf.fit(X_train, y_train)
        
        y_pred = np.array(clf.predict(X_test))
        tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
        
        accuracy.append(accuracy_score(y_test, y_pred))
        score.append(score_function(tp,fp,fn,tn))
        f1.append(f1_score(y_test, y_pred))
        
    return np.mean(score)

def cat_param_tune(params, X, y ,n_splits=5):
    ps = paramsearch(params)
    for prms in chain(ps.grid_search(['border_count']),
                      ps.grid_search(['l2_leaf_reg']),
                      ps.grid_search(['iterations','learning_rate']),
                      ps.grid_search(['depth'])):
        res = crossvaltest_cat(prms,X, y,n_splits)
        ps.register_result(res,prms)
        print(res,prms,'best:',ps.bestscore(),ps.bestparam())
        print()
    return ps.bestparam(), ps.bestscore()

标签: pythoncross-validationgridsearchcv

解决方案


推荐阅读