python - 使用班级权重后 F1 分数降低
问题描述
我正在研究一个多类分类用例,数据高度不平衡。通过高度不平衡的数据,我的意思是频率最高的类别和频率最低的类别之间存在巨大差异。因此,如果我继续使用,SMOTE oversampling
那么数据量会大大增加(数据量从 280k 行增加到超过 250 亿行,因为不平衡性太高了)并且实际上不可能将 ML 模型拟合到如此庞大的数据集。同样,我不能使用欠采样,因为这会导致信息丢失。
所以我想compute_class_weight
在创建 ML 模型时使用 from sklearn。
代码:
from sklearn.utils.class_weight import compute_class_weight
class_weight = compute_class_weight(class_weight='balanced',
classes=np.unique(train_df['Label_id']),
y=train_df['Label_id'])
dict_weights = dict(zip(np.unique(train_df['Label_id']), class_weight))
svc_model = LinearSVC(class_weight=dict_weights)
我对测试数据进行了预测,并记录了 , 等指标的结果accuracy
。f1_score
我recall
尝试复制相同但不通过class_weight
,如下所示:
svc_model = LinearSVC()
但是我得到的结果很奇怪。通过后class_weight
的指标比没有的指标差一些class_weight
。
我希望完全相反,因为我正在使用class_weight
它来使模型更好,从而使指标更好。
两个模型的指标之间的差异很小,但与没有f1_score
模型class_weight
相比,模型的差异较小class_weight
。
我还尝试了以下代码段:
svc_model = LinearSVC(class_weight='balanced')
但f1_score
与没有的模型相比,它仍然更少class_weight
。
以下是我获得的指标:
LinearSVC w/o class_weight
Accuracy: 89.02, F1 score: 88.92, Precision: 89.17, Recall: 89.02, Misclassification error: 10.98
LinearSVC with class_weight=’balanced’
Accuracy: 87.98, F1 score: 87.89, Precision: 88.3, Recall: 87.98, Misclassification error: 12.02
LinearSVC with class_weight=dict_weights
Accuracy: 87.97, F1 score: 87.87, Precision: 88.34, Recall: 87.97, Misclassification error: 12.03
我认为使用class_weight
会改善指标,但会恶化指标。为什么会发生这种情况,我该怎么办?如果我不处理不平衡数据可以吗?
解决方案
我如何看待问题
我对您的问题的理解是,您的班级权重方法实际上是在改进您的模型,但您(可能)看不到它。原因如下:
假设您有 10 个 POS 和 1k NEG 样本,并且您有两个模型:M-1 正确预测了所有 NEG 样本(假阴性率 = 0),但仅正确预测了 10 个 POS 样本中的 2 个。M-2 正确预测了 700 个 NEG 和 8 个 POS 样本。从异常检测的角度来看,第二个模型可能是首选,而第一个模型(显然陷入了不平衡问题)具有更高的 f1 分数。
类权重将尝试解决您的不平衡问题,将您的模型从 M-1 转移到 M-2。因此,您的 f1 分数可能会略有下降。但你可能有一个质量更好的模型。
你如何验证我的观点
您可以通过查看混淆矩阵来检查我的观点,以查看 f1 分数是否由于您的主要课程的更多错误分类而降低,以及您的次要课程现在是否有更多的真阳性。此外,您可以专门针对不平衡类测试其他指标。我知道Cohen 的 Kappa也许你看到班级权重实际上增加了 Kappa 分数。
还有一件事:做一些引导或交叉验证,f1 分数的变化可能是由于数据的可变性而没有任何意义
推荐阅读
- node.js - 在 node.js 项目中运行节点控制台是否会加载 package.json 中的所有依赖项?
- flutter - 从 Json 文件中获取数据但地图不起作用
- macos - 为什么在 docker bind mount 中创建的文件的 uid 会因主机操作系统而异?
- html - 联系表 7 - 需要有关简码和价值的帮助
- iphone - 谷歌地图在移动网站上显示错误的位置
- c# - VSTO Outlook 加载项发布为按需加载使其处于非活动状态
- mysql - 如何在窗口 10 中的 MySql 中创建一个返回单个十进制值的函数
- sql - oracle中sql比plsql快吗
- css - 将 JSON 配置文件导入到 styles.css
- javascript - $ 和 document 已定义,但 $(document) 未定义