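python-3.x - Wrong number of features. How do I fix this?
Problem description
Please help: I'm getting an error about the number of features. The columns used are an ID, encoded columns, and integer columns. This code works on another dataset with similar but more features. Could it be that too few features are used, and that is why this error appears? Here is my code: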
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import make_scorer, accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split

# X_all, y_all and data_test are built earlier in the script (not shown)
num_test = 0.20  # 80-20 split
X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=23)

clf = RandomForestClassifier()
parameters = {'n_estimators': [4, 6, 9],
              'max_features': ['log2', 'sqrt', 'auto'],  # note: 'auto' was removed in scikit-learn 1.3
              'criterion': ['entropy', 'gini'],
              'max_depth': [2, 3, 5, 10],
              'min_samples_split': [2, 3, 5],
              'min_samples_leaf': [1, 5, 8]
              }
acc_scorer = make_scorer(accuracy_score)
grid_obj = GridSearchCV(clf, parameters, scoring=acc_scorer)
grid_obj = grid_obj.fit(X_train, y_train)

# Best parameter combination found by the grid search
clf = grid_obj.best_estimator_
clf.fit(X_train, y_train)

ids = data_test['Id']
# This is the line that raises the error (line 98 in the traceback)
predictions = clf.predict(data_test.drop('Id', axis=1))
output = pd.DataFrame({'Id': ids, 'Full_Time_Home_Goals': predictions})
print(output.head())
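A side note on the code above, separate from the error: with GridSearchCV's default refit=True, the best parameter combination is already refitted on all the data passed to grid_obj.fit, so grid_obj.best_estimator_ can predict directly and the extra clf.fit(X_train, y_train) only retrains it (with a new random forest, since no random_state is set). A minimal sketch on hypothetical toy data, with a smaller grid than the question's:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Hypothetical deterministic toy data: 30 rows, 2 integer-encoded columns, 3 classes
X_train = np.column_stack([np.arange(30) % 7, np.arange(30) % 11])
y_train = np.arange(30) % 3

grid_obj = GridSearchCV(
    RandomForestClassifier(random_state=0),
    {"n_estimators": [4, 6], "max_depth": [2, 3]},  # reduced grid for the example
    scoring="accuracy",  # same metric as make_scorer(accuracy_score)
    cv=3,
)
grid_obj.fit(X_train, y_train)  # refit=True by default

clf = grid_obj.best_estimator_
# Already trained on all of X_train: it has fitted trees and knows its input width
print(len(clf.estimators_), clf.n_features_in_)
print(clf.predict(X_train[:3]))
```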
The error I get is:
> Traceback (most recent call last):
>   File "C:/Users/harsh/PycharmProjects/Kaggle-Machine Learning from Start to Finish with Scikit-Learn/EPL Predicting.py", line 98, in <module>
>     predictions = clf.predict(data_test.drop('Id', axis=1))
>   File "C:\Users\harsh\PycharmProjects\GitHub\venv\lib\site-packages\sklearn\ensemble\_forest.py", line 629, in predict
> ValueError: Number of features of the model must match the input. Model n_features is 4 and input n_features is 2
Even if I don't drop Id in predictions = clf.predict(data_test.drop('Id', axis=1)), I still get the error.
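The message says the forest was fitted on 4 columns while predict() received 2, and scikit-learn rejects any input whose column count differs from the one seen during fit. That also explains why keeping Id does not help: it gives 3 columns, which is still not 4. A small reproduction on hypothetical toy data (not the question's dataset); newer scikit-learn versions word the message differently, along the lines of "X has 2 features, but RandomForestClassifier is expecting 4 features as input", but still raise ValueError:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Hypothetical toy data: the forest is fitted on a matrix with 4 columns
X_fit = np.arange(40).reshape(10, 4)
y_fit = np.arange(10) % 2
model = RandomForestClassifier(n_estimators=5, random_state=0).fit(X_fit, y_fit)
print(model.n_features_in_)  # 4

# Predicting on 3 columns (like keeping Id) or 2 columns (like dropping it) fails either way
failures = []
for n_cols in (3, 2):
    X_new = np.arange(3 * n_cols).reshape(3, n_cols)
    try:
        model.predict(X_new)
    except ValueError as exc:
        failures.append(n_cols)
        print(n_cols, "columns ->", exc)  # exact wording depends on the scikit-learn version
print(failures)  # [3, 2]
```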
Sample dataset:
data_train:
Id HomeTeam AwayTeam Full_Time_Home_Goals
0 1 55 440 3
1 2 158 493 2
2 3 178 745 1
3 4 185 410 1
4 5 249 57 2
data_test:
Id HomeTeam AwayTeam
0 190748 284 54
1 190749 124 441
2 190750 446 57
3 190751 185 637
4 190752 749 482
The columns look the way they should for this to work. Why doesn't it?
Solution
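One way to remove the mismatch, sketched on the sample rows and assuming the intended inputs are only HomeTeam and AwayTeam: keep the list of feature columns in one place and select exactly those columns from both the training and the test frame. FEATURES and TARGET are names introduced here for illustration, and the plain RandomForestClassifier stands in for grid_obj.best_estimator_ so the example runs on five rows; the grid search itself would be run on X_all and y_all (or their train split) as before.

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# The sample rows from the question, typed in by hand
data_train = pd.DataFrame({"Id": [1, 2, 3, 4, 5],
                           "HomeTeam": [55, 158, 178, 185, 249],
                           "AwayTeam": [440, 493, 745, 410, 57],
                           "Full_Time_Home_Goals": [3, 2, 1, 1, 2]})
data_test = pd.DataFrame({"Id": [190748, 190749, 190750, 190751, 190752],
                          "HomeTeam": [284, 124, 446, 185, 749],
                          "AwayTeam": [54, 441, 57, 637, 482]})

# Single definition of the model inputs (assumed: the two team columns only)
FEATURES = ["HomeTeam", "AwayTeam"]
TARGET = "Full_Time_Home_Goals"

X_all = data_train[FEATURES]  # no Id, no target
y_all = data_train[TARGET]

# Stand-in for the tuned estimator from the grid search
clf = RandomForestClassifier(n_estimators=10, random_state=23)
clf.fit(X_all, y_all)

# Same columns, in the same order, at prediction time
predictions = clf.predict(data_test[FEATURES])
output = pd.DataFrame({"Id": data_test["Id"], TARGET: predictions})
print(output.head())
```

Selecting data_test[FEATURES] rather than dropping columns also keeps the column order identical to training, and a missing column raises a KeyError that names it instead of failing later inside predict().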