python - 网格搜索中具有第三个参数的自定义记分器
问题描述
我需要制作一个计分器,根据三个列表/数组y_true
、y_pred
和sample_value
. 问题是网格搜索中的记分器计算了训练集和验证集的分数,我不知道如何区分它。这就是我尝试做的事情(完整示例):
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
def RF_metric(y_true, y_pred, sample_value):
dict_temp = {'y_pred': list(y_pred), 'y_true': list(y_true),
'sample_value': sample_value}
df_temp = pd.DataFrame(dict_temp)
df_temp['daily_score'] = df_temp[['y_pred', 'y_true', 'sample_value']].apply(
lambda row: row[2] if row[0] == row[1] else -row[2], axis=1)
df_temp['cum_score'] = df_temp['daily_score'].cumsum()
final_score = df_temp['cum_score'].to_list()[-1]
return final_score
param_dict = {'n_estimators': [100, 150, 200],
'max_depth': [5, 10, 15],
}
dates = pd.date_range(start='2020-01-01', end='2020-10-01')
df = pd.DataFrame({'Date': dates, 'A':np.random.rand(len(dates)), 'B':np.random.rand(len(dates)),
'label':np.random.choice([0,1],len(dates)), 'sample_value':np.random.rand(len(dates))})
train_start = pd.to_datetime('2020-01-01')
train_end = pd.to_datetime('2020-06-01')
val_start = train_end
val_end = pd.to_datetime('2020-07-01')
df_train = df[(train_start <= df['Date']) & (df['Date'] < train_end)]
df_val = df[(val_start <= df['Date']) & (df['Date'] < val_end)]
cv_list = [(list(df_train.index), list(df_val.index))]
X = df[['A', 'B']].values
Y = df[['label']].values.ravel()
clf = RandomForestClassifier()
scoring = make_scorer(RF_metric, sample_value = df_val['sample_value'].to_list())
gs = GridSearchCV(clf, param_dict, cv=cv_list, scoring=scoring,n_jobs=4)
gs.fit(X,Y)
错误是ValueError: arrays must all be same length
解决方案
由于升级您的版本有帮助,似乎问题在于return_train_score
过去默认为True
,因此您确实scoring
通过了训练集但具有验证的sample_value
.
一种解决方案(例如,如果您仍然想要训练分数,或者想要切换到 kfold 交叉验证,这将有所帮助)是不使用便利功能make_scorer
。它只是返回一个带有签名的可调用对象,(estimator, X, y)
其中较大的“分数”更好。您可以编写自己的此类可调用对象,然后您可以访问所有内容X
(包括列sample_value
!),而不仅仅是估算器的预测。
推荐阅读
- sql-server - TSQL - 将数据从一个表复制到另一个表
- php - 基于当前日期的 ebay api 过滤器列表
- html - css中边框样式的结束
- c# - 无法使用脚本任务将更新的值分配给 ssis 变量
- javascript - 如何遍历一个对象并将金额汇总到每种付款类型?
- python - 为 COCO 数据集生成 TFRecord
- r - 在同一张图上使用 R 将两个变量绘制为线
- scala - 使用 Option 作为输入参数定义 Spark scala UDF
- simpleitk - 在sitk.ReadImage()中应该如何使用参数?
- javascript - JQuery获取请求回调函数未执行