首页 > 解决方案 > Pandas:如何为数据集中的所有项目创建 userByItem pivot_table

问题描述

假设,我有一个如下所示的数据集:

User Item Rating
u1   i1   3
u2   i2   4
u3   i3   5
u4   i1   2
u5   i2   1
u5   i4   3
u1   i4   2

我将数据集分成两部分,训练数据集:

User Item Rating
u1   i1   3
u2   i2   4
u3   i3   5

和测试数据集:

User Item Rating
u4   i1   2
u5   i2   1
u5   i4   3
u1   i4   2

如果我使用以下代码从这 2 个拆分的数据集创建 2 个 pivot_table:

 trainPivot = pd.pivot_table(trainData, values='Rating',
                                index=['User'], columns=['Item'])
 testPivot = pd.pivot_table(testData, values='Rating',
                                index=['User'], columns=['Item'])

然后生成的 pivot_tables 看起来像这样,对于训练数据:

       I1      I2         I3
U1     3       Null       Null
U2     Null    4          Null
U3     Null    Null       5

对于测试数据:

       I1      I2    I4      
U4     2       Null  Null   
U5     Null    1     3

但我希望我的 pivot_tables 看起来像这样,对于火车数据:

       I1      I2         I3     I4
U1     3       Null       Null   Null
U2     Null    4          Null   Null
U3     Null    Null       5      Null

对于测试数据:

       I1      I2    I3    I4      
U1     Null    Null  Null  2
U4     2       Null  Null  Null   
U5     Null    1     Null  3

如何使用 pivot_table 方法在 Pandas 数据框中实现这一点。

标签: pythonpandaspivot-table

解决方案


关键是如果项目列不存在,则添加它。

我不确定是什么Null,所以我插入nan以保持默认的 pandas 格式。

import pandas as pd
import numpy as np


data = pd.DataFrame({
    'User': ['u1', 'u2', 'u3', 'u4', 'u5', 'u5', 'u1'],
    'Item': ['i1', 'i2', 'i3', 'i1', 'i2', 'i4', 'i4'],
    'Rating': [3, 4, 5, 2, 1, 3, 2]
})

train_data = data.head(3)
test_data = data.tail(4)

train_pivot = pd.pivot_table(
    train_data, values='Rating', index=['User'], columns=['Item']
)
test_pivot = pd.pivot_table(
    test_data, values='Rating', index=['User'], columns=['Item']
)

unique_items = data['Item'].unique()

for item in unique_items:
    if item not in test_pivot:
        test_pivot[item] = np.nan
    if item not in train_pivot:
        train_pivot[item] = np.nan

# If you want the columns sorted alphabetically
train_pivot = train_pivot.reindex_axis(sorted(train_pivot.columns), axis=1)
test_pivot = test_pivot.reindex_axis(sorted(test_pivot.columns), axis=1)

输出结果:

train_pivot

Item   i1   i2   i3  i4
User                   
u1    3.0  NaN  NaN NaN
u2    NaN  4.0  NaN NaN
u3    NaN  NaN  5.0 NaN

test_pivot

Item   i1   i2  i3   i4
User                   
u1    NaN  NaN NaN  2.0
u4    2.0  NaN NaN  NaN
u5    NaN  1.0 NaN  3.0

推荐阅读