python - 如何在水平堆叠条形图中构建层次结构标签
问题描述
我想在 y 轴上有一个带有层次标签的水平堆积条形图。我搜索了一下,发现了以下很好的示例和代码。
但它适用于垂直堆积条形图。我想将其应用于水平条形图,所以我只是更改了kind='barh'
,但这不起作用。
我设法通过在最后几行中将所有 x 更改为 y 来删除默认的 ylabels。但是在定义的函数中将 x 更改为 y 并没有给我我想要的:层次结构标签仍然在 x 轴上。
任何人都可以帮忙吗?谢谢。
PS:为了让事情不那么混乱,我发布了我从这个问题的第二个答案中找到的原始代码,而不是我试图修改的那个
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby
def test_table():
data_table = pd.DataFrame({'Room': ['Room A'] * 4 + ['Room B'] * 3,
'Shelf': ['Shelf 1'] * 2 + ['Shelf 2'] * 2 + ['Shelf 1'] * 2 + ['Shelf 2'],
'Staple':['Milk', 'Water', 'Sugar', 'Honey', 'Wheat', 'Corn', 'Chicken'],
'Quantity': [10, 20, 5, 6, 4, 7, 2,],
'Ordered': np.random.randint(0, 10, 7)
})
data_table
def add_line(ax, xpos, ypos):
line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
transform=ax.transAxes, color='black')
line.set_clip_on(False)
ax.add_line(line)
def label_len(my_index,level):
labels = my_index.get_level_values(level)
return [(k, sum(1 for i in g)) for k,g in groupby(labels)]
def label_group_bar_table(ax, df):
ypos = -.1
scale = 1./df.index.size
for level in range(df.index.nlevels)[::-1]:
pos = 0
for label, rpos in label_len(df.index,level):
lxpos = (pos + .5 * rpos)*scale
ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
add_line(ax, pos*scale, ypos)
pos += rpos
add_line(ax, pos*scale , ypos)
ypos -= .1
df = test_table().groupby(['Room','Shelf','Staple']).sum()
fig = plt.figure()
ax = fig.add_subplot(111)
df.plot(kind='bar',stacked=True,ax=fig.gca())
#Below 3 lines remove default labels
labels = ['' for item in ax.get_xticklabels()]
ax.set_xticklabels(labels)
ax.set_xlabel('')
label_group_bar_table(ax, df)
fig.subplots_adjust(bottom=.1*df.index.nlevels)
plt.show()
解决方案
你可以做这样的事情。
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
import numpy as np
data_table = pd.DataFrame({'Room': ['Room A'] * 4 + ['Room B'] * 3,
'Shelf': ['Shelf 1'] * 2 + ['Shelf 2'] * 2 + ['Shelf 1'] * 2 + ['Shelf 2'],
'Staple': ['Milk', 'Water', 'Sugar', 'Honey', 'Wheat', 'Corn', 'Chicken'],
'Quantity': [10, 20, 5, 6, 4, 7, 2, ],
'Ordered': np.random.randint(0, 10, 7)
})
arrays = [list(data_table['Room']), list(data_table['Shelf']), list(data_table['Staple'])]
data_table = data_table.groupby(['Room', 'Shelf', 'Staple']).sum()
index = pd.MultiIndex.from_tuples(list(zip(*arrays)))
df = pd.DataFrame(data_table[['Ordered', 'Quantity']], index=index).T
# plotting
fig = plt.figure()
height_ratios = [len(df[df.columns.levels[0][0]].columns), len(df[df.columns.levels[0][1]].columns)] #i.e. 4, 3
gs = gridspec.GridSpec(nrows=len(df.columns.levels[0]), ncols=1, height_ratios=height_ratios)
ax1 = fig.add_subplot(gs[0,0])
ax2 = fig.add_subplot(gs[1,0], sharex=ax1)
axes = [ax1, ax2]
for i, col in enumerate(df.columns.levels[0]):
print(col)
ax = axes[i]
df[col].T.plot(ax=ax, stacked=True, kind='barh', width=.8)
ax.legend_.remove()
ax.set_ylabel(col, weight='bold')
ax.xaxis.grid(b=True, which='major', color='black', linestyle='--', alpha=.4)
ax.set_axisbelow(True)
for tick in ax.get_xticklabels():
tick.set_rotation(0)
ax.legend()
# make the ticklines invisible
ax.tick_params(axis=u'both', which=u'both', length=0)
plt.tight_layout()
# remove spacing in between
fig.subplots_adjust(wspace=0) # space between plots
plt.show()
推荐阅读
- apache-spark - 如何在 Spark 中有效地抓取多个表的列?
- javascript - 按顺序将附加了 jquery 的 Firebase 数据显示到表中
- jquery - 使用 jquery 动画滚动顶部是一个错误
- amazon-web-services - AWS- AWS command to view lambda function in Command Line
- php - 数组的一部分不通过 foreach 访问
- json - 如何从以下 JSON 数据中提取文件名
- python - Convert a string with arithmetic operation to an interger in python
- android - 如何处理整个应用程序共有的 user_id 、登录令牌等字段
- system-verilog - SystemC 与 SystemVerilog 的仿真速度
- azure - Azure 图形 API 没有响应