
Problem description

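I'm trying to include a classification algorithm's error as an attribute and then apply K-Means clustering on the same dataset (the scikit-learn Wine dataset). The desired DataFrame consists of the test-data instances, the ground-truth labels, and the algorithm's error for each instance. It should look like the following table: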

   alcohol  malic_acid   ash  true class  predicted class  error
0    14.23        1.71  2.43           1                3     -2
1    13.71        5.65  2.45           2                2      0
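A minimal sketch of that target layout, using only the two example rows above as toy data (the column names malic_acid and ash are assumptions that follow the Wine feature names). The error is true class minus predicted class, which gives -2 for the first row:

```python
import pandas as pd

# Toy values copied from the two example rows above (illustrative only)
features = pd.DataFrame(
    {"alcohol": [14.23, 13.71], "malic_acid": [1.71, 5.65], "ash": [2.43, 2.45]}
)
true_class = pd.Series([1, 2], name="true class")
predicted_class = pd.Series([3, 2], name="predicted class")

# All three pieces share the same default index (0, 1), so concat lines them up row by row
df_out = pd.concat([features, true_class, predicted_class], axis=1)

# Error as defined in the table: true class minus predicted class
df_out["error"] = df_out["true class"] - df_out["predicted class"]
print(df_out)
```

This only works because every piece shares one index; the real split data does not, which is where the trouble below comes from.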

Here is my reproducible code:

from sklearn.datasets import load_wine
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

data = load_wine()
df_data = pd.DataFrame(data.data)
df_target = pd.DataFrame(data = data.target)

# Dividing X and y into train and test data (small train data to gain more errors)
X_train, X_test, y_train, y_test = train_test_split(df_data, df_target, test_size=0.40, random_state=2)

# Training a RandomForest Classifier 
model = RandomForestClassifier()
model.fit(X_train, y_train)

# Obtaining predictions
y_hat = model.predict(X_test)

# Converting y_hat from a NumPy array to a DataFrame
predictions_col = pd.DataFrame()
predictions_col['predictions'] = y_hat.tolist()
predictions_col['true classes'] = y_test


# adding predictions to test data
df_out = pd.merge(df_data, predictions_col, left_index = True, right_index = True)
df_out

[Output snippet: screenshot not preserved]
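For context on where NaN can enter before any merge happens: train_test_split keeps the original row labels, so y_test carries a shuffled index, while predictions_col, built from a plain list, gets a fresh 0..n-1 index. Column assignment in pandas aligns on index labels, not on positions. A toy sketch with made-up values (y_test_like and frame are illustrative stand-ins, not names from the question):

```python
import pandas as pd

# A Series whose index labels are shuffled, like y_test after train_test_split
y_test_like = pd.Series([0, 1, 2], index=[12, 1, 0])

# A frame built from a plain list gets a fresh RangeIndex 0..2, like predictions_col
frame = pd.DataFrame({"predictions": [0, 1, 2]})

# Column assignment aligns on index labels, not on position:
# label 0 -> 2, label 1 -> 1, label 2 is absent from y_test_like -> NaN
frame["true classes"] = y_test_like
print(frame)
```

The same label alignment applies when the right-hand side is a one-column DataFrame such as y_test in the code above, which is also why the column turns into floats.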

Tags: python pandas dataframe scikit-learn

Solution


You're trying to merge the predictions for X_test with df_data, the DataFrame that holds the entire dataset, hence the NaN values.
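pd.merge with left_index=True and right_index=True joins on index labels, not row positions. A toy sketch (full, test_part and preds are made-up stand-ins for df_data, X_test and predictions_col) showing how merging with the full data pairs predictions with the wrong rows, and what changes once the predictions carry the test rows' own labels:

```python
import pandas as pd

# Stand-in for df_data: five rows labelled 0..4
full = pd.DataFrame({"feature": [10, 11, 12, 13, 14]})

# Stand-in for X_test: a shuffled subset that keeps its original labels
test_part = full.loc[[3, 0, 4]]

# Stand-in for predictions_col: built from a list, so it gets a fresh RangeIndex 0..2
preds = pd.DataFrame({"predictions": [1, 0, 1]})

# Joining on the index pairs rows by label: rows 0, 1, 2 of `full`
# receive the predictions that were made for test rows 3, 0, 4
wrong = pd.merge(full, preds, left_index=True, right_index=True)
print(wrong)

# Giving preds the test rows' labels makes the join line up
preds.index = test_part.index
right = pd.merge(test_part, preds, left_index=True, right_index=True)
print(right)
```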

First, define your error column in predictions_col:

predictions_col['error'] = predictions_col['true classes'] - predictions_col['predictions']

Then merge it with X_test instead of df_data:

df_out = pd.merge(X_test, predictions_col, left_index=True, right_index=True)
print(df_out)

# result:

        0     1     2     3  ...      12  predictions  true classes  error
12  13.75  1.73  2.41  16.0  ...  1320.0            0           0.0    0.0
23  12.85  1.60  2.52  17.8  ...  1015.0            1           0.0   -1.0
25  13.05  2.05  3.22  25.0  ...   830.0            0           0.0    0.0
35  13.48  1.81  2.41  20.5  ...   920.0            1           0.0   -1.0
13  14.75  1.73  2.39  11.4  ...  1150.0            1           0.0   -1.0
65  12.37  1.21  2.56  18.1  ...   678.0            0           1.0    1.0
48  14.10  2.02  2.40  18.8  ...  1060.0            2           0.0   -2.0
3   14.37  1.95  2.50  16.8  ...  1480.0            0           0.0    0.0
6   14.39  1.87  2.45  14.6  ...  1290.0            1           0.0   -1.0
42  13.88  1.89  2.59  15.0  ...  1095.0            1           0.0   -1.0
2   13.16  2.36  2.67  18.6  ...  1185.0            2           0.0   -2.0
29  14.02  1.68  2.21  16.0  ...  1035.0            0           0.0    0.0
45  14.21  4.04  2.44  18.9  ...  1080.0            1           0.0   -1.0
5   14.20  1.76  2.45  15.2  ...  1450.0            0           0.0    0.0
53  13.77  1.90  2.68  17.1  ...  1375.0            1           0.0   -1.0
41  13.41  3.84  2.12  18.8  ...  1035.0            2           0.0   -2.0
54  13.74  1.67  2.25  16.4  ...  1060.0            1           0.0   -1.0
24  13.50  1.81  2.61  20.0  ...   845.0            0           0.0    0.0
64  12.17  1.45  2.53  19.0  ...   355.0            2           1.0   -1.0
28  13.87  1.90  2.80  19.4  ...   915.0            0           0.0    0.0
14  14.38  1.87  2.38  12.0  ...  1547.0            2           0.0   -2.0
44  13.05  1.77  2.10  17.0  ...   885.0            2           0.0   -2.0
66  13.11  1.01  1.70  15.0  ...   502.0            2           1.0   -1.0
57  13.29  1.97  2.68  16.8  ...  1270.0            0           0.0    0.0
71  13.86  1.51  2.67  25.0  ...   410.0            2           1.0   -1.0
11  14.12  1.48  2.32  16.8  ...  1280.0            0           0.0    0.0
36  13.28  1.64  2.84  15.5  ...   880.0            2           0.0   -2.0
62  13.67  1.25  1.92  18.0  ...   630.0            2           1.0   -1.0
0   14.23  1.71  2.43  15.6  ...  1065.0            0           0.0    0.0
27  13.30  1.72  2.14  17.0  ...  1285.0            0           0.0    0.0

[30 rows x 16 columns]
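A caveat about the result above: predictions_col, as built in the question, has a fresh 0..71 index, while X_test keeps the shuffled labels it had in df_data. The merge therefore keeps only the labels present in both (30 of the 72 test rows, hence "30 rows"). The "true classes" values do line up, because they were aligned by label, but the row labelled i receives the prediction made for the i-th row of X_test, so the predictions and error columns do not describe the rows they sit on. A sketch that builds predictions_col on X_test's own index and then runs the K-Means step from the question; the random_state values, StandardScaler, n_clusters=3, and clustering on the features plus error are assumptions, not part of the original answer:

```python
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

data = load_wine()
df_data = pd.DataFrame(data.data)
# A Series instead of a one-column DataFrame avoids sklearn's column-vector warning in fit()
df_target = pd.Series(data.target, name="true classes")

X_train, X_test, y_train, y_test = train_test_split(
    df_data, df_target, test_size=0.40, random_state=2
)

model = RandomForestClassifier(random_state=0)  # random_state added only for repeatability
model.fit(X_train, y_train)
y_hat = model.predict(X_test)

# Key change: reuse X_test's own (shuffled) index so each prediction stays attached
# to the row it was made for; y_test already carries that same index
predictions_col = pd.DataFrame(
    {"predictions": y_hat, "true classes": y_test}, index=X_test.index
)
predictions_col["error"] = predictions_col["true classes"] - predictions_col["predictions"]

df_out = pd.merge(X_test, predictions_col, left_index=True, right_index=True)
print(df_out.shape)  # one row per test instance

# Clustering step (assumed setup): scale the 13 features plus the error column
X_cluster = df_out.drop(columns=["predictions", "true classes"]).to_numpy()
X_cluster = StandardScaler().fit_transform(X_cluster)
df_out["cluster"] = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X_cluster)
print(df_out[["true classes", "predictions", "error", "cluster"]].head())
```

Converting to a NumPy array before scaling sidesteps the fact that recent scikit-learn versions reject (and older ones warn about) DataFrames whose column names mix integers and strings, as 0..12 plus "error" do here. Scaling matters because Wine features such as proline are orders of magnitude larger than the error column.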
