python - 使用 train_test_split 后分类器准确率为 100%
问题描述
我正在研究蘑菇分类数据集(可在此处找到:https ://www.kaggle.com/uciml/mushroom-classification )。
我正在尝试将我的数据拆分为我的模型的训练和测试集,但是如果我使用 train_test_split 方法,我的模型总是可以达到 100% 的准确度。当我手动拆分数据时,情况并非如此。
x = data.copy()
y = x['class']
del x['class']
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33)
model = xgb.XGBClassifier()
model.fit(x_train, y_train)
predictions = model.predict(x_test)
print(confusion_matrix(y_test, predictions))
print(accuracy_score(y_test, predictions))
这会产生:
[[1299 0]
[ 0 1382]]
1.0
如果我手动拆分数据,我会得到更合理的结果。
x = data.copy()
y = x['class']
del x['class']
x_train = x[0:5443]
x_test = x[5444:]
y_train = y[0:5443]
y_test = y[5444:]
model = xgb.XGBClassifier()
model.fit(x_train, y_train)
predictions = model.predict(x_test)
print(confusion_matrix(y_test, predictions))
print(accuracy_score(y_test, predictions))
结果:
[[2007 0]
[ 336 337]]
0.8746268656716418
什么可能导致这种行为?
编辑: 根据要求,我包括切片的形状。
train_test_split:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33)
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
结果:
(5443, 64)
(5443,)
(2681, 64)
(2681,)
手动拆分:
x_train = x[0:5443]
x_test = x[5444:]
y_train = y[0:5443]
y_test = y[5444:]
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)
结果:
(5443, 64)
(5443,)
(2680, 64)
(2680,)
我已经尝试定义我自己的拆分函数,结果拆分也导致 100% 的分类器准确度。
这是拆分的代码
def split_data(dataFrame, testRatio):
dataCopy = dataFrame.copy()
testCount = int(len(dataFrame)*testRatio)
dataCopy = dataCopy.sample(frac = 1)
y = dataCopy['class']
del dataCopy['class']
return dataCopy[testCount:], dataCopy[0:testCount], y[testCount:], y[0:testCount]
解决方案
你在 train_test_split 上很幸运。您手动进行的拆分可能包含最不可见的数据,这比 train_test_split 进行更好的验证,后者在内部对数据进行混洗以拆分它。
为了更好地验证,请使用 K-fold 交叉验证,这将允许使用数据中的每个不同部分作为测试和其余部分作为训练来验证模型的准确性。
推荐阅读
- java - 我在获得输出方面做错了什么?
- c# - 无法添加迁移。找不到具有不变名称“Oracle.ManagedDataAccess.Client”的 ADO.NET 提供程序的实体框架提供程序
- r - 我们如何将顺式调节元件映射到其特定的启动子区域?
- angularjs - Angularjs 版本问题:默认选择禁用选项
- r - R - 循环遍历用户定义模型上的指标列表
- r - R代码无法识别转义字符
- android - 如何从饼图中的数据库中获取数据?
- dart - 使用 video_player 的文件构造函数拒绝权限
- c# - Xamarin android gui在暂停单独的线程时冻结
- android - 在 libgdx 项目中添加 Admob