首页 > 解决方案 > 网格搜索中具有第三个参数的自定义记分器

问题描述

我需要制作一个计分器,根据三个列表/数组y_truey_predsample_value. 问题是网格搜索中的记分器计算了训练集和验证集的分数,我不知道如何区分它。这就是我尝试做的事情(完整示例):

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer

def RF_metric(y_true, y_pred, sample_value):

    dict_temp = {'y_pred': list(y_pred), 'y_true': list(y_true),
                 'sample_value': sample_value}
    
    df_temp = pd.DataFrame(dict_temp)
    df_temp['daily_score'] = df_temp[['y_pred', 'y_true', 'sample_value']].apply(
        lambda row: row[2] if row[0] == row[1] else -row[2], axis=1)

    df_temp['cum_score'] = df_temp['daily_score'].cumsum()
    final_score = df_temp['cum_score'].to_list()[-1]

    return final_score

param_dict = {'n_estimators': [100, 150, 200],
                  'max_depth': [5, 10, 15],
                  }

dates = pd.date_range(start='2020-01-01', end='2020-10-01')
df = pd.DataFrame({'Date': dates, 'A':np.random.rand(len(dates)), 'B':np.random.rand(len(dates)),
                   'label':np.random.choice([0,1],len(dates)), 'sample_value':np.random.rand(len(dates))})

train_start = pd.to_datetime('2020-01-01')
train_end = pd.to_datetime('2020-06-01')
val_start = train_end
val_end = pd.to_datetime('2020-07-01')

df_train = df[(train_start <= df['Date']) & (df['Date'] < train_end)]
df_val = df[(val_start <= df['Date']) & (df['Date'] < val_end)]

cv_list = [(list(df_train.index), list(df_val.index))]
X = df[['A', 'B']].values
Y = df[['label']].values.ravel()

clf = RandomForestClassifier()
scoring = make_scorer(RF_metric, sample_value = df_val['sample_value'].to_list())

gs = GridSearchCV(clf, param_dict, cv=cv_list, scoring=scoring,n_jobs=4)
gs.fit(X,Y) 

错误是ValueError: arrays must all be same length

标签: pythonmachine-learningscikit-learngrid-search

解决方案


由于升级您的版本有帮助,似乎问题在于return_train_score过去默认为True,因此您确实scoring通过了训练集但具有验证的sample_value.

一种解决方案(例如,如果您仍然想要训练分数,或者想要切换到 kfold 交叉验证,这将有所帮助)是不使用便利功能make_scorer。它只是返回一个带有签名的可调用对象,(estimator, X, y)其中较大的“分数”更好。您可以编写自己的此类可调用对象,然后您可以访问所有内容X(包括列sample_value!),而不仅仅是估算器的预测。


推荐阅读