python - logit 和 sklearn 管道的一种热编码
问题描述
我正在尝试使用 Python 中的Dalex包来可视化二进制 logit 模型的某些特征。
我从这里的示例书中复制了一段代码 (整个第五个代码单元),但现在我不太确定应该如何解释结果......
在我使用创建的基本 logit 模型中,我statsmodels
为每个类别手动选择了一个参考水平变量,以避免多重共线性(这意味着模型的所有结果都相对于参考水平进行解释)。
但是当我使用上面链接中的一段代码(也复制到这篇文章下面)时,它首先在 中创建一些管道对象sklearn
,one-hot 对分类变量进行编码,然后将管道对象拟合到数据并在Dalex解释器作为要解释的模型。
问题是,当我使用model_profile()
Dalex 中的函数时,它应该输出一个图表,显示变量对预测的其他条件不变的影响,我不知道如何解释结果,因为似乎一个分类变量中的所有值包含在图表中。
例如,该模型显示了“性别”分类变量对男性和女性平均预测的影响......
这也显示了一条名为“平均预测”的水平线,但“平均预测”是什么?它是根据男性作为参考水平计算的,还是女性?
我对结果的含义感到非常困惑......有人可以澄清一下吗?model_profile()
我尝试使用的功能也在笔记本中进行了说明。谢谢!
我复制的一段代码:
numerical_features = ['age', 'fare', 'sibsp', 'parch']
numerical_transformer = Pipeline(
steps=[
('imputer', SimpleImputer(strategy='median')),
('scaler', StandardScaler())
]
)
categorical_features = ['gender', 'class', 'embarked']
categorical_transformer = Pipeline(
steps=[
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
('onehot', OneHotEncoder(handle_unknown='ignore'))
]
)
preprocessor = ColumnTransformer(
transformers=[
('num', numerical_transformer, numerical_features),
('cat', categorical_transformer, categorical_features)
]
)
classifier = MLPClassifier(hidden_layer_sizes=(150,100,50), max_iter=500, random_state=0)
clf = Pipeline(steps=[('preprocessor', preprocessor),
('classifier', classifier)])
clf.fit(X, y)
exp = dx.Explainer(clf, X, y)
解决方案
为什么会这样?
发生这种情况是因为默认情况下,sklearn
'sOneHotEncoder
会对数据中的每个类别进行一次热转换。然而,对于像 logit 这样的线性模型,通常最好将其中一个类别排除在外,以避免多重共线性,并使您的结果可以相对于参考点进行解释。在这种情况下,您需要更改编码器的默认设置。
例子
您可以通过设置来实现这一点drop="first"
,它会删除一个热编码过程的第一类。下面的示例说明了这将如何在一个简单的示例中起作用。在这里,“女性”类别从一个热门编码中删除,只有“男性”类别被编码,这将返回您期望的结果。请注意,这也适用于非二进制特征。
from sklearn.preprocessing import OneHotEncoder
X = pd.DataFrame({"gender":["male","female","female","male"]})
OHE = OneHotEncoder(drop="first")
OHE.fit_transform(X).toarray()
#[[1.],
# [0.],
# [0.],
# [1.]]
OHE.get_feature_names()
#['x0_male']
你需要做什么
因此,您需要在代码中更改的只是管道定义中的以下行:
'onehot', OneHotEncoder(drop='first', handle_unknown='ignore')
推荐阅读
- xamarin.ios - Xamarin.UITest:如何选择当前节点的下一个元素
- java - 添加/删除页面时,Android FragmentStatePagerAdapter 未正确更新 ViewPager
- security - Hacking & Seucrity : 订阅 CMS 和框架升级
- c# - 车速表指针旋转统一
- android - 无限滚动列表视图 - 在构建期间调用 setState() 或 markNeedsBuild
- java - 从 MainActivity 访问一个单独的进程
- angular - ng6:在 html 中动态形成变量名
- javascript - 对 flatList 中的项目进行排序
- django - Django inlineformset 如何不显示 id
- angular-material - mat-select disabled 没有从 angular5 的范围内拾取变量