python - 我的分类器在所有测试数据集上给出了 1.0 的准确度(错误的照片除外)
问题描述
有:
- 数据集:115 张 256x256 大小的彩色图像,所有照片都属于ONE 类(卡通人物)。
- 分类器:KNN 和随机森林分类器。
评论:我想制作一个分类器来预测某张照片上的卡通人物,所以我收集了一个数据集,将其数字化并放入分类器的拟合方法中。所以一开始,我选择了SGDClassifier
,但它只适用于数据集中的 2 个或更多类。于是选择了KNN和随机森林分类器。
问题:当我尝试测试我准备好的分类器时,我在每张照片上都得到了1.0 分(我测试了 1 个对象、1 个另一个对象(另一个卡通人物)和一张黑屏的照片),无论如何它们都有 1.0 分。
有人可以帮我吗?:(我已经被困在这两天了,看不到自己解决问题的方法,我看了很多解决方案,但没有一个对我有用。
数据集:
- 我的数据集 numpy 数组的形状是(115, 196608)并且(例如)我的数据集 numpy 数组中的一张图像如下所示:
- 数据集是一个二维数组,因为分类器只采用一维或二维数组。
代码:不全,仅举个例子
train_data_values = numpy.array([*115 photos*])
train_data_labels = numpy.array([*115 labels*])
# For fact, all my labels equal "1", there is no other value.
# Trying KNN
from sklearn.neighbors import KNeighborsClassifier
KNN_clf = KNeighborsClassifier(**{'n_neighbors': 16, 'weights': 'distance'})
KNN_clf.fit(train_data_values, train_data_labels)
test_im = cv2.imread(DATASET_IMAGES_DIRECTORY + "\\test\\" + "test2.png")
KNN_clf.predict_proba(test_im.reshape(1, 3*256*256)) # Returns array([[1.]])
# Trying Random Forest Classifier
from sklearn.ensemble import RandomForestClassifier
RF_clf = RandomForestClassifier()
RF_clf.fit(train_data_values, train_data_labels)
test_im = cv2.imread(DATASET_IMAGES_DIRECTORY + "\\test\\" + "test.png")
RF_clf.predict_proba(test_im.reshape(1, 3*256*256)) # Returns array([[1.]])
评论:我查看了我的 numpy 数据集中的图像,因为我认为它们可能被错误地数字化,但是不,它们可以轻松地从数组到图像构建。
KNN 分类器的 PS 参数是随机的,因为我一直在尝试网格搜索来寻找最佳参数,但到处都是 1.0 分。
解决方案
所有分类器都从他们的训练数据中学习他们的分数。大多数分类器(包括随机森林和 KNN)的分数都具有概率意义:它们被调整以尽可能地反映训练数据的概率分布。
因此,如果您的训练数据由 100% 的单个类别组成,那么分类器将以 100% 的概率学习任何样本属于该类别,并且将以绝对置信度预测该类别。
教训:要使用任何分类器,您至少需要两个类,否则,预测或多或少将毫无意义。我的建议是添加负样本,即没有你的目标人的样本,包括:
- 与您和其他漫画中的其他人的图像
- 只有背景和没有人的图像
- 带有一些非动画对象的图像
有一些例外,例如OneClassSVM,它们(可能)能够在单个类上产生有意义的分数。但是,在您使用来自几个不同类别的数据对它们进行测试之前,您永远不会知道它们是否能充分处理您的数据。
推荐阅读
- matlab - Matlab Simscape Toolbox 使用惯性和质量块
- docker - 这是 asp.net core 3.1 内存泄漏的迹象吗
- c# - OPOS 是否仅适用于 x86 .NET 应用程序?
- antlr4 - 如何为 ANTLR 创建文档?
- php - Git 克隆返回“错误:无效路径 'public/C:\Users\My-PC\Documents\Projects\Sample\storage\logs/laravel.log'”
- mongodb - 如何使用聚合获取 Mongo 集合中所有级别的所有键的名称
- javascript - HighCharts 导出/加载图表
- javascript - 我如何在 mongodb 中使用它们的参考来获取分层数据?
- jquery - 如何使用 JQuery 从没有 id 标识符的 img 更改 src 路径?
- c++ - 内核参数类型必须满足哪些确切的约束?