首页 > 解决方案 > Keras 中类别子部分的准确度指标

问题描述

我有一个三级分类问题。让我们将它们定义为 0,1 和 2 类。在我的例子中,0 类并不重要——也就是说,任何被归类为 0 类的东西都无关紧要。然而,相关的是仅适用于第 1 类和第 2 类的准确度、精确度、召回率和错误率。我想定义一个准确度指标,它只查看与 1 和 2 相关的数据的子部分,并给我一个衡量标准因为模型正在训练。我不是要求准确性或 f1 或精度/召回的代码——那些我已经找到并且可以自己实现的代码。我要找的代码可以帮助选择类别的一个小节来执行这些指标。在视觉上,使用混淆矩阵: 给定:

>  0   1   2
>0 10  3   4
>1 2   5   1
>2 8   5   9

我只想在训练中对以下子集执行准确性测量:

>  1   2
>1 5   1
>2 5   9

可能的想法:连接一个分类的、argmaxed y_pred 和 argmaxed y_true,删除所有出现 0 的实例,将它们重新解开成一个 one_hot 数组,并对剩下的内容进行简单的二进制精度?

编辑:我试图通过这段代码排除 0 类,但这没有意义。0 类被有效地包装到 1 类中(也就是说,0 和 1 的真阳性最终都被标记为 1)。仍在寻求帮助 - 有人可以帮忙吗?

#this solution does not work :(
def my_acc(y_true, y_pred):
#excluding the 0-category
y_true_cust = y_true[:,np.r_[1:3]]
y_pred_cust = y_pred[:,np.r_[1:3]]
#binary accuracy source code, slightly edited
y_pred_cat = Ker.round(y_pred_cust)
eql_cust = Ker.equal(y_true_cust, y_pred_cust)
return Ker.mean(eql_cust, axis = -1)

@Ashwin Geet D'Sa

correct_guesses_3cat = 10 + 5 + 9
print(correct_guesses_3cat)
24

total_guesses_3cat = 10+3+4+2+5+1+8+5+9
print(total_guesses_3cat)
47

accuracy_3cat = 24/47
print(accuracy_3cat)
51.1 %

correct_guesses_2cat =5 + 9
print(correct_guesses_2cat)
14

total_guesses_2cat = 5+1+5+9
print(total_guesses_2cat)
20

accuracy_2cat = 14/20
print(accuracy_2cat)
70.0 %

标签: pythonmachine-learningkerasmetricstensor

解决方案


推荐阅读