python - Numpy 错误“ValueError:找到昏暗 3 的数组。预计估计器 <= 2。”
问题描述
我是一个使用 scikit-learn 的 ML 的完全新手,我只是想在我花了很多时间学习 ML 的类型等等之后让它工作。
from sklearn import tree
import pandas as pd
import numpy as np
df = pd.read_csv('test.csv')
age = df.Age.to_list()
age = np.array(age).reshape(-1,1)
inc = df.Income.to_list()
inc = np.array(inc).reshape(-1,1)
stud = df.Student.to_list()
stud = np.array(stud).reshape(-1,1)
buy = df.Buy.to_list()
buy = np.array(buy).reshape(-1,1)
X = [age,inc,stud]
y = [[buy]]
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
'''
Income:
1 - high
2 - medium
3 - low
Student:
1 - yes
2 - no
'''
age = 34
inc = 1
stud = 2
pred = clf.predict(age,ince,stud)
print(pred)
但我得到这个错误:
回溯(最后一次调用):文件“D:\Huzefa\Desktop\ML.py”,第 23 行,在 clf = clf.fit(X, y) 文件“C:\Users\Huzefa\AppData\Local\Programs \Python\Python36\lib\site-packages\sklearn\tree_classes.py”,第 894 行,适合 X_idx_sorted=X_idx_sorted) 文件“C:\Users\Huzefa\AppData\Local\Programs\Python\Python36\lib\site- packages\sklearn\tree_classes.py”,第 158 行,适合 check_y_params))文件“C:\Users\Huzefa\AppData\Local\Programs\Python\Python36\lib\site-packages\sklearn\base.py”,行429,在 _validate_data X = check_array(X, **check_X_params) 文件“C:\Users\Huzefa\AppData\Local\Programs\Python\Python36\lib\site-packages\sklearn\utils\validation.py”中,第 73 行, 在 inner_f 返回 f(**kwargs) 文件 "C:\Users\Huzefa\AppData\Local\Programs\Python\Python36\lib\site-packages\sklearn\utils\validation.py",第 642 行,在 check_array % (array.ndim, estimator_name)) ValueError: 找到暗淡的数组3. 估计器预期 <= 2。
如果我可以更正我的脚本以使其正常工作,我将有动力继续使用 ML 继续提供所有帮助,非常感谢!
解决方案
您定义 X 和 y 的方式对我来说似乎过于复杂,这种选择背后有什么具体原因吗?您还可以执行以下操作:
X = df[["Age","Income","Student"]]
y = df.Buy
另外,通过做
clf = clf.fit(X, y)
您正在根据所有可用数据训练您的决策树。如果这是一个训练数据集并且你有一个存储在其他地方的测试数据集,那没关系;如果没有,您需要先拆分数据,以便您可以训练模型并测试所述训练的效率。train_test_split
是一个有用的功能。
推荐阅读
- r - 文件上传后更新 checkboxGroupInput() 选项
- java - IntelliJ 模块与 Java 9 模块有什么关系吗?
- javascript - 使用 PUG 在表格单元格中插入图像
- graphql - 我可以使用指令根据其他字段值计算字段值吗?
- sql - 返回集合的 SQL 查询
- c# - 如何使用 Autofac 注册调解员的 IPipelineBehavior
- c# - 使用 ASP.NET Core 和 Swagger 在 api 请求中出现 404 错误
- android - 如何逐渐隐藏recyclerview滚动条上的项目
- r - 为什么 Substitute(variable_x) 得到一个数值而不是 'name' 对象
- python - 如何抓取谷歌地图标记上显示的数据