classification - 如何计算多类分类中的不平衡准确度度量
问题描述
很抱歉打扰,但我发现了一篇有趣的文章“Mortaz, E. (2020). Imbalance accuracy metric for model selection in multi-class 不平衡分类问题。基于知识的系统, 210, 106490”(https://www .sciencedirect.com/science/article/pii/S0950705120306195),他们在那里计算这个度量(IAM)(公式在论文中,我明白了),但我想问:我怎样才能在R上复制它?
我提前为这个愚蠢的问题道歉。感谢您的关注!
解决方案
文中提供的IAM公式为:IAM公式
其中 cij 是分类器混淆矩阵 (c) 中的元素 (i,j)。k是指分类中的类数(k>=2)。结果表明,该度量可以用作多类模型选择中的单独度量。
下面提供了在 python 中实现 IAM(不平衡准确度指标)的代码:
def IAM(c):
'''
c is a nested list presenting the confusion matrix of the classifier (len(c)>=2)
'''
l = len(c)
iam = 0
for i in range(l):
sum_row = 0
sum_col = 0
sum_row_no_i = 0
sum_col_no_i = 0
for j in range(l):
sum_row += c[i][j]
sum_col += c[j][i]
if j is not i:
sum_row_no_i += c[i][j]
sum_col_no_i += c[j][i]
iam += (c[i][i] - max(sum_row_no_i, sum_col_no_i))/max(sum_row, sum_col)
return iam/l
c = [[2129, 52, 0, 1],
[499, 70, 0, 2],
[46, 16, 0, 1],
[85, 18, 0, 7]]
IAM(c) = -0.5210576475801445
在 R 中实现 IAM(不平衡准确度指标)的代码如下:
IAM <- function(c) {
# c is a matrix representing the confusion matrix of the classifier.
l <- nrow(c)
result = 0
for (i in 1:l) {
sum_row = 0
sum_col = 0
sum_row_no_i = 0
sum_col_no_i = 0
for (j in 1:l){
sum_row = sum_row + c[i,j]
sum_col = sum_col + c[j,i]
if(i != j) {
sum_row_no_i = sum_row_no_i + c[i,j]
sum_col_no_i = sum_col_no_i + c[j,i]
}
}
result = result + (c[i,i] - max(sum_row_no_i, sum_col_no_i))/max(sum_row, sum_col)
}
return(result/l)
}
c <- matrix(c(2129,52,0,1,499,70,0,2,46,16,0,1,85,18,0,7), nrow=4, ncol=4)
IAM(c) = -0.5210576475801445
另一个来自 iris 数据集(3 类问题)和 sklearn 的例子:
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter = 1000).fit(X, y)
pred = clf.predict(X)
c = confusion_matrix(y, pred)
print('confusion matrix:')
print(c)
print(f'accuarcy : {clf.score(X, y)}')
print(f'IAM : {IAM(c)}')
confusion matrix:
[[50 0 0]
[ 0 47 3]
[ 0 1 49]]
accuarcy : 0.97
IAM : 0.92
推荐阅读
- python - Python - Pebble - 对超时功能的误解
- java - 为什么这给了我java中的算术错误?
- php - 如何在 SQL 中找到最大值并使用 PDO 读入 PHP?
- php - 在php中使用ajax保持会话活跃
- json - 解析嵌套的 json 文件
- c++ - 从缓冲区输入的位
- r - 如何创建按年份组织的每日间隔的时间序列
- vue.js - Vue + SSR | 如何将mixin传输到配置文件?
- dart - Flutter / Dart:尽管设置了状态变量但未设置
- angular - TypeError:window.angular.element 不是 Ionic 4 SQLite 中的函数