首页 > 解决方案 > 多索引数据帧中每个子帧上的熊猫点积

问题描述

我有以下数据df_matrixlevel_1并且level_2是多索引:

|level_1|level_2|value_1|value_2|value_3|
|-------|-------|-------|-------|-------|
|a      |w      |1      |2      |3      |
|       |y      |4      |5      |6      |
|       |y      |4      |5      |6      |
|       |z      |7      |8      |9      |
|b      |w      |11     |21     |31     |
|       |x      |41     |51     |61     |
|       |y      |41     |51     |61     |
|       |z      |71     |81     |91     |

df_column,id是索引:

ID 价值
值_1 0.1
价值_2 0.2
值_3 0.3

有没有一种聪明的方法可以在每个子帧上进行点积而不显式循环?

我是这样做的,但想知道是否有更可爱的隐式方式,谢谢,约翰

import pandas as pd
# set up matrix data
df_matrix = pd.DataFrame([(1, 2, 3),
                   (4, 5, 6),
                   (4, 5, 6),
                   (7, 8, 9),
                   (11, 21, 31),
                   (41, 51, 61),
                   (41, 51, 61),
                   (71, 81, 91)],
                  index=[['a', 'a', 'a', 'a', 'b', 'b', 'b', 'b'], ['w', 'x', 'y', 'z', 'w', 'x', 'y', 'z']],
                  columns=('value_1','value_2','value_3'))

# BTW can I do this rename in constructor?
df_matrix.index.rename(['level_1','level_2'], inplace=True)


# set up column data
df_column = pd.DataFrame([('value_1', 0.1), ('value_2', 0.2), ('value_3',0.3)],
                columns=('level_2', 'factor'))

df_column.set_index('level_2', inplace=True)


# loop each sub frame and do matrix multiplication
df_result = pd.DataFrame()
for l1, new_df in df_matrix.groupby(level=0):
    new_df.reset_index(level=0, inplace=True, drop=True)
    df_column.rename(columns={df_column.columns[0] : l1}, inplace=True)
    df_scores = new_df.dot(df_column)
    df_result = pd.concat([df_result, df_scores], axis=1)
    
# result:
df_result.T

#level_2 w    x    y    z
#a       1.4  3.2  3.2  5.0
#b       14.6 32.6 32.6 50.6

标签: pandasmatrixindexingmultiplication

解决方案


您可以直接使用点函数;它将在公共索引上对齐;之后,它是一个简单的 unstack、droplevel 和重命名。

(
    df_matrix.dot(df_column)
    .unstack()
    .droplevel(0, axis=1)
    .rename_axis(index=None, columns=None)
)

     w       x       y      z
a   1.4     3.2      3.2    5.0
b   14.6    32.6    32.6    50.6

推荐阅读