python - 如何从 GridsearchCV 获取 feature_importances_
问题描述
我对编程很陌生,这个问题可能很容易解决,但我已经坚持了一段时间,我认为我的方法显然是错误的。正如标题所示,我一直在尝试对我的 RandomForest 预测实施网格搜索,以找到模型的最佳可能参数,然后查看具有最佳参数的模型的最重要特征。我用过的包:
import nltk
from nltk.corpus import stopwords
import pandas as pd
import string
import re
import pickle
import os
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
经过一些数据清理和预处理后,我进行了这样的网格搜索,其中 x_features 是具有我数据的 tfidfvectorized 特征的 DataFrame:
param = {'n_estimators':[10, 50, 150], 'max_depth':[10, 30, 50, None], 'min_impurity_decrease':[0, 0.01, 0.05, 0.1], 'class_weight':["balanced", None]}
gs = GridSearchCV(rf, param, cv=5, n_jobs=-1)
gs_fit = gs.fit(x_features, mydata['label'])
optimal_param = pd.DataFrame(gs_fit.cv_results_).sort_values('mean_test_score', ascending = False)[0:5]
optimal_param1 = gs_fit.best_params_
我的想法是,也许我可以让自己变得容易,并将最佳参数 1 复制到我的 RandomForestClassifier() 中,并或多或少地适合我的训练数据:
rf = RandomForestClassifier(optimal_param2)
rf_model= rf.fit(x_train, y_train)
但optimal_param2 是一个字典。因此,我认为将其转换为字符串并消除过多的符号( sub : for =, delete {, delete } )会使其工作。这显然失败了,因为 n_estimators、max_depth 等的数字仍然是字符串并且它需要整数。我最终想要实现的是获得最重要功能的输出,或多或少像这样:
top25_features = sorted(zip(rf_model.feature_importances_, x_train.columns),reverse=True)[0:25]
我意识到 gs 已经是一个完整的 RF 模型,但它没有我正在寻找的属性 feature_importances_。我将非常感谢有关如何使其发挥作用的任何想法。
解决方案
一旦你跑了gs_fit=gs.fit(X,y)
,你就拥有了你需要的一切,你不需要做任何再训练。
首先,您可以通过以下方式访问最佳模型:
best_estimator = gs_fit.best_estimator_
这将返回产生最佳结果的随机森林。然后你可以通过做访问这个模型的特征重要性
best_features = best_estimator.feature_importances_
显然,您可以链接这些并直接执行:
best_features = gs_fit.best_estimator_.feature_importances_
推荐阅读
- c# - 触摸键盘隐藏 UI 元素 [Windows 10 和 WPF]
- mysql - 创建一个表,其中包含名为 age 的列,以使用 MYSQL(数据库)从出生日期计算年龄
- android - 如何将usb相机与cameraX一起使用?
- python - 用 pandas 中的多个 Headers 重塑和总结 Excel
- c# - C# Visual Studio 的 Chrome 驱动程序问题
- wordpress - 为 wordpress 站点备份启用文件下载需要哪些 gcp 防火墙设置?
- url - 显示嵌入在 Google Blogger 博客文章中的网址的缩略图
- python - 在python中删除具有负值的字典
- latex - 在词汇表/首字母缩略词索引之前结束乳胶的最后一章
- cypress - 如何遍历属性值并检查它是否包含字符串