首页 > 解决方案 > 如何比较数据并从 pandas 的 multiIndex 数据框中选择 TOP 2?

问题描述

如何比较数据并从 pandas 的 multiIndex 数据框中选择 TOP 2 或 TOP5?你在这个例子中看到,如果foo只得到一条记录,它将只选择一条。但如果有 3 条记录,则会选择 TOP2 记录。

例如:

arrays = [np.array(['bar', 'bar', 'bar', 'bar', 'baz','baz', 'baz', 'qux', 'qux','qux', 'qux','foo']),
          np.array(['AA', 'AB', 'AC','AD', 'BA', 'BB', 'BC', 'CA', 'CB', 'CC', 'CD', 'DA'])]
df = pd.DataFrame(np.random.randn(12, 1), index=arrays)
df

出去:

         0
bar AA  -0.754077
    AB   0.924327
    AC   0.146192
    AD  -0.718730
baz BA  -0.143378
    BB   1.098409
    BC   0.703452
qux CA   0.729626
    CB   0.232755
    CC   0.827796
    CD   0.914639
foo DA  -0.289108

最后,我想这样选择:

         0
bar AB   0.924327
    AC   0.146192     
baz BB   1.098409
    BC   0.703452
qux CC   0.827796
    CD   0.914639
foo DA  -0.289108

标签: pythonpandas

解决方案


利用:

np.random.seed(234)
arrays = [np.array(['bar', 'bar', 'bar', 'bar', 'baz','baz', 'baz', 'qux', 'qux','qux', 'qux','foo']),
          np.array(['AA', 'AB', 'AC','AD', 'BA', 'BB', 'BC', 'CA', 'CB', 'CC', 'CD', 'DA'])]
df = pd.DataFrame(np.random.randn(12, 1), index=arrays)
print (df)
               0
bar AA  0.818792
    AB -1.043551
    AC  0.350901
    AD  0.921578
baz BA -0.087382
    BB -3.128885
    BC -0.969733
qux CA  0.934666
    CB  0.043866
    CC  1.425216
    CD -0.557063
foo DA  0.926824

解决方案SeriesGroupBy.nlargest

s = df.groupby(level=0)[0].nlargest(2).reset_index(level=0, drop=True)
print (s)
bar  AD    0.921578
     AA    0.818792
baz  BA   -0.087382
     BC   -0.969733
foo  DA    0.926824
qux  CC    1.425216
     CA    0.934666
Name: 0, dtype: float64

如果需要避免排序MultiIndex

df1 = (df.groupby(level=0, sort=False)[0]
       .nlargest(2)
       .reset_index(level=0, drop=True)
       .to_frame())
print (df1)

               0
bar AD  0.921578
    AA  0.818792
baz BA -0.087382
    BC -0.969733
qux CC  1.425216
    CA  0.934666
foo DA  0.926824

另一种解决方案,在pandas 0.23.0+中使用sort_valuesand GroupBy.head

df.index.names = ['lvl1','lvl2']
df.columns = ['a']
s = df.sort_values(['lvl1', 'a'], ascending=[True, False]).groupby(level=0).head(2)
print (s)
                  a
lvl1 lvl2          
bar  AD    0.921578
     AA    0.818792
baz  BA   -0.087382
     BC   -0.969733
foo  DA    0.926824
qux  CC    1.425216
     CA    0.934666

推荐阅读