machine-learning - 无法为多标签分类器进行堆叠
问题描述
我正在研究一个多标签文本分类问题(目标标签总数 90)。数据分布具有长尾和类别不平衡以及大约 10 万条记录。我正在使用 OAA 策略(一对一)。我正在尝试使用 Stacking 创建一个合奏。
文本特征:(HashingVectorizer
特征数 2**20,字符分析器)
TSVD 降低维度(n_components=200)。
text_pipeline = Pipeline([
('hashing_vectorizer', HashingVectorizer(n_features=2**20,
analyzer='char')),
('svd', TruncatedSVD(algorithm='randomized',
n_components=200, random_state=19204))])
feat_pipeline = FeatureUnion([('text', text_pipeline)])
estimators_list = [('ExtraTrees',
OneVsRestClassifier(ExtraTreesClassifier(n_estimators=30,
class_weight="balanced",
random_state=4621))),
('linearSVC',
OneVsRestClassifier(LinearSVC(class_weight='balanced')))]
estimators_ensemble = StackingClassifier(estimators=estimators_list,
final_estimator=OneVsRestClassifier(
LogisticRegression(solver='lbfgs',
max_iter=300)))
classifier_pipeline = Pipeline([
('features', feat_pipeline),
('clf', estimators_ensemble)])
错误
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-41-ad4e769a0a78> in <module>()
1 start = time.time()
----> 2 classifier_pipeline.fit(X_train.values, y_train_encoded)
3 print(f"Execution time {time.time()-start}")
4
3 frames
/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in column_or_1d(y, warn)
795 return np.ravel(y)
796
--> 797 raise ValueError("bad input shape {0}".format(shape))
798
799
ValueError: bad input shape (89792, 83)
解决方案
StackingClassifier
目前不支持多标签分类。您可以通过查看fit
参数的形状值来了解这些功能,例如此处。
解决方案是将 OneVsRestClassifier 包装器放在StackingClassifier
各个模型之上。
例子:
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.svm import LinearSVC
from sklearn.ensemble import StackingClassifier
from sklearn.multiclass import OneVsRestClassifier
X, y = make_multilabel_classification(n_classes=3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.33,
random_state=42)
estimators_list = [('ExtraTrees', ExtraTreesClassifier(n_estimators=30,
class_weight="balanced",
random_state=4621)),
('linearSVC', LinearSVC(class_weight='balanced'))]
estimators_ensemble = StackingClassifier(estimators=estimators_list,
final_estimator = LogisticRegression(solver='lbfgs', max_iter=300))
ovr_model = OneVsRestClassifier(estimators_ensemble)
ovr_model.fit(X_train, y_train)
ovr_model.score(X_test, y_test)
# 0.45454545454545453
推荐阅读
- sql - SAP SQL IF 语句
- spring-integration - Spring Integration:Http请求->响应->处理->如果有更多数据可用->循环并重复,直到没有收到数据
- tomcat9 - tomcat 9 第二个实例记录到默认 catalina.out 位置
- docker - Docker 映像无法正确下载每个文件
- android - android项目中两个库的重复类
- arrays - 如何在查询之前在 ReactJS 中隐藏列表
- c++ - c++20 范围视图到向量
- java - 如果我通过片段对象抛出函数是干净的代码
- r - 向 ggplot2 ggridges (joyplot) 添加辅助轴以显示每个 bin 中的计数
- python - 替换熊猫系列中的负值