catboost - 如何在 CatBoostClassifier.fit() 之后获取评估指标?
问题描述
我已经训练了一个分类模型调用CatBoostClassifier.fit()
,还提供了一个eval_set
.
现在,我怎样才能获取评估指标的最佳值,以及在训练期间实现的迭代次数?plot=True
我可以通过设置调用来绘制信息fit()
,但是如何将其分配给变量?
我可以在训练模型调用时做到这一点cv()
,因为它会cv()
返回所需的信息。但CatBoostClassifier.fit()
根据文档不返回任何内容。
这是我用来拟合模型的代码片段:
model = CatBoostClassifier(
random_seed=42,
logging_level='Silent',
eval_metric='Accuracy'
)
model.fit(X_train,
y_train,
cat_features=cat_features_idxs,
eval_set=(X_val, y_val),
plot=True
)
如果我改用以下方法,我将如何设法获取所需的信息cv()
:
cv_data = cv(Pool(X, y, cat_features = cat_features_idxs),
model.get_params(),
fold_count = 5,
plot=True)
print('Validation accuracy (best average among cross-validation folds) is {} obtained at step {}.'.format(np.max(cv_data['test-Accuracy-mean']), np.argmax(cv_data['test-Accuracy-mean'])))
解决方案
1)只需计算训练数据的分数:
https://stackoverflow.com/a/17954831
model = CatBoostClassifier(
random_seed=42,
logging_level='Silent',
eval_metric='Accuracy'
)
model.fit(X_train,
y_train,
cat_features=cat_features_idxs,
eval_set=(X_val, y_val),
plot=True
)
train_score = model.score(X_train, y_train) # train (learn) score
val_score = model.score(X_val, y_val) # val (test) score
另一种方法是访问输出文件:
model = CatBoostClassifier(
random_seed=42,
logging_level='Silent',
eval_metric='Accuracy',
allow_writing_files=True
)
model.fit(X_train,
y_train,
cat_features=cat_features_idxs,
eval_set=(X_val, y_val),
plot=True
)
import pandas as pd
test_error = pd.read_csv('catboost_info/test_error.tsv', sep='\t')
val_score = test_error.loc[test_error['Accuracy'] == test_error['Accuracy'].max()]['Accuracy'].values[0]
best_iter = int(test_error.loc[test_error['Accuracy'] == test_error['Accuracy'].min()]['iter'].values[0])
train_score = learn_error.loc[learn_error['iter'] == best_iter]['Accuracy'].values[0]
2) 如果你安装了 pandasas_pandas=True
作为参数添加cv
,那么你可以访问 cv_data 作为数据框。例如cv_data['test-Accuracy-mean'].max()
。
https://tech.yandex.com/catboost/doc/dg/concepts/python-reference_cv-docpage/
您也可以像上面那样访问输出文件,在这种情况下,每个折叠都会有一对文件夹。
希望这可以帮助!
推荐阅读
- c# - 如何在使用 NewtonSoft 反序列化对象时使用 JProperty Name 属性?
- php - FPDI 中的 500 内部服务器错误 $pdf = new Fpdi();
- kubernetes - Master2 和 Master3 处于 notReady 状态。似乎与 kubelet.conf 中的用户配置有关
- jenkins - 如何通过 groovy 工作流程从登录的用户 id 中获取 jenkins 中的角色名称
- pjsip - 我应该在后台线程中调用 pjsua_call_make_call (pj::Call::makeCall) 吗?
- xslt - XSLT 标记插入
- javascript - 如何在选择框下拉列表中突出显示某些值,第一次扩展列表
- python - 查找正确的正则表达式以匹配模式并在 python 中提取子字符串
- c - C write() 时间不一致
- google-bigquery - 将某一列中的值替换为具有特定条件 BQ 的同一行中另一列中的值