首页 > 解决方案 > Seaborn 热图 - 要显示的行和列统计信息

问题描述

是否可以在 Seaborn 热图的边缘添加行和列统计信息?

因此,对于右侧的每一行,我想显示行平均值(每个月),在年份的底部边缘,我想显示每列的列平均值。

在此处输入图像描述

标签: pythonpandasdataframedatetimeseaborn

解决方案


如果您正在使用这样的数据框:

df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))
          date  value
0   1949-01-01    202
1   1949-02-01    535
2   1949-03-01    448
3   1949-04-01    370
4   1949-05-01    206
..         ...    ...
139 1960-08-01    238
140 1960-09-01    598
141 1960-10-01    180
142 1960-11-01    491
143 1960-12-01    262

你必须重新塑造pandas.DataFrame.pivot

df['month'] = df['date'].dt.month_name().str.slice(stop = 3).sort_values()
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')
year   1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960
month                                                                        
Apr     370   472   485   574   463   487   543   101   301   395   479   591
Aug     120   230   260   287   230   341   530   359   450   437   114   238
Dec     314   443   352   545   120   485   519   501   561   509   426   262
Feb     535   558   513   444   545   266   191   459   143   351   351   443
Jan     202   430   591   335   274   428   439   149   317   314   316   108
Jul     288   251   376   575   419   113   363   205   369   336   256   162
Jun     171   459   543   269   343   415   527   153   583   307   140   571
Mar     448   187   393   148   150   373   466   487   261   289   287   228
May     206   199   291   158   154   188   554   489   545   312   592   235
Nov     566   357   121   289   234   152   180   290   555   379   444   491
Oct     221   408   413   370   406   445   305   576   370   152   164   180
Sep     202   249   559   563   584   364   134   409   403   466   400   598

然后,您可以添加一个具有月份平均值的列和一个具有年份平均值的行:

df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)
year       1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960  month_mean
month                                                                                        
Apr         370   472   485   574   463   487   543   101   301   395   479   591         438
Aug         120   230   260   287   230   341   530   359   450   437   114   238         299
Dec         314   443   352   545   120   485   519   501   561   509   426   262         419
Feb         535   558   513   444   545   266   191   459   143   351   351   443         399
Jan         202   430   591   335   274   428   439   149   317   314   316   108         325
Jul         288   251   376   575   419   113   363   205   369   336   256   162         309
Jun         171   459   543   269   343   415   527   153   583   307   140   571         373
Mar         448   187   393   148   150   373   466   487   261   289   287   228         309
May         206   199   291   158   154   188   554   489   545   312   592   235         326
Nov         566   357   121   289   234   152   180   290   555   379   444   491         338
Oct         221   408   413   370   406   445   305   576   370   152   164   180         334
Sep         202   249   559   563   584   364   134   409   403   466   400   598         410
year_mean   303   353   408   379   326   338   395   348   404   353   330   342         357

或者,您可以使用以下方法进行旋转和计算均值pandas.pivot_table

df = pd.pivot_table(data = df, columns = 'year', index = 'month', values = 'value', margins = True)
year   1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960  All
month                                                                             
Apr     370   472   485   574   463   487   543   101   301   395   479   591  438
Aug     120   230   260   287   230   341   530   359   450   437   114   238  299
Dec     314   443   352   545   120   485   519   501   561   509   426   262  419
Feb     535   558   513   444   545   266   191   459   143   351   351   443  399
Jan     202   430   591   335   274   428   439   149   317   314   316   108  325
Jul     288   251   376   575   419   113   363   205   369   336   256   162  309
Jun     171   459   543   269   343   415   527   153   583   307   140   571  373
Mar     448   187   393   148   150   373   466   487   261   289   287   228  309
May     206   199   291   158   154   188   554   489   545   312   592   235  326
Nov     566   357   121   289   234   152   180   290   555   379   444   491  338
Oct     221   408   413   370   406   445   305   576   370   152   164   180  334
Sep     202   249   559   563   584   364   134   409   403   466   400   598  410
All     303   353   408   379   326   338   395   348   404   353   330   342  357

唯一的区别是最后一列和最后一行的名称。
现在您已准备好绘制热图:

fig, ax = plt.subplots()

sns.heatmap(ax = ax, data = df, annot = True, fmt = '.0f')

plt.show()

完整代码

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))

df['month'] = df['date'].dt.month_name().str.slice(stop = 3).sort_values()
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')

df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)


fig, ax = plt.subplots()

sns.heatmap(ax = ax, data = df, annot = True, fmt = '.0f')

plt.show()

在此处输入图像描述


或者,您可以更改最后一列和最后一行的颜色图,以提高可见性:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))

df['month'] = df['date'].dt.month_name().str.slice(stop = 3).sort_values()
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')

df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)

df_values = df.copy()
df_values['month_mean'] = float('nan')
df_values.loc['year_mean'] = float('nan')

df_means = df.copy()
df_means.loc[:-1, :-1] = float('nan')


fig, ax = plt.subplots()

sns.heatmap(ax = ax, data = df_values, annot = True, fmt = '.0f', cmap = 'Reds', vmin = df.to_numpy().min(), vmax = df.to_numpy().max())
sns.heatmap(ax = ax, data = df_means, annot = True, fmt = '.0f', cmap = 'Blues', vmin = df.to_numpy().min(), vmax = df.to_numpy().max())

plt.show()

在此处输入图像描述


推荐阅读