首页 > 解决方案 > 彼此相交的堆叠条形图

问题描述

我有堆栈条形图的以下代码

cols = ['Bug Prediction','Traceability','Security', 'Program Generation & Repair',
        'Performance Prediction','Code Similarity & Clone Detection',
        'Code Navigation & Understanding', 'Other_SE'] 
count_ANN = [2.0,0.0,1.0,0.0,0.0,3.0,5.0,1.0] 
count_CNN = [1.0,0.0,5.0,0.0,1.0,4.0,4.0,0.0]
count_RNN = [1.0,0.0,3.0,1.0,0.0,4.0,7.0,2.0] 
count_LSTM =[3.0,0.0,5.0,3.0,1.0,9.0,15.0,1.0]
count_GNN = [0.0,0.0,1.0,0.0,0.0,3.0,3.0,3.0] 
count_AE =  [0.0,0.0,1.0,3.0,0.0,6.0,11.0,0.0]
count_AM =  [2.0,0.0,1.0,4.0,1.0,4.0,15.0,1.0]
count_other =[1.0,0.0,2.0,2.0,0.0,1.0,3.0,0.0]
b_RNN = list(np.add(count_ANN,count_CNN))
b_LSTM = list(np.add(np.add(count_ANN,count_CNN),count_RNN))
b_AE = list(np.add(np.add(np.add(count_ANN,count_CNN),count_RNN),count_AE))
b_GNN = list(np.add(b_AE,count_GNN))
b_others = list(np.add(b_GNN,count_other))
plt.bar(cols,count_ANN,0.4,label = "ANN")
plt.bar(cols,count_CNN,0.4,bottom=count_ANN,label = "CNN")
plt.bar(cols,count_RNN,0.4,bottom=b_RNN,label = "RNN")
plt.bar(cols,count_LSTM,0.4,bottom =b_LSTM, label = "LSTM")
plt.bar(cols,count_AE,0.4,bottom=b_AE,label = "Auto-Encoder")
plt.bar(cols,count_GNN,0.4,bottom=b_GNN,label = "GNN")
plt.bar(cols,count_other,0.4,bottom=b_others,label = "Others")
#ax.bar(cols, count)
plt.xticks(np.arange(len(cols))+0.1,cols)
fig.autofmt_xdate()
plt.legend()
plt.show()

然后这个输出是重叠的堆栈,如下图所示 在此处输入图像描述

标签: pythonmatplotlib

解决方案


具体问题是b_AE计算错误。(此外,还有一个count_AM没有标签的列表)。

更普遍的问题是,“手动”计算所有这些值很容易出错,并且在发生变化时难以适应。它有助于在循环中编写内容。

numpy 的广播和矢量化的魔力让您可以初始化bottom为单个零,然后使用 numpy 的加法来添加计数。

要使 x 轴更整洁,您可以将单个单词放在单独的行上。此外,plt.tight_layout()尝试确保所有文本都很好地融入情节。

import matplotlib.pyplot as plt
import numpy as np

cols = ['Bug Prediction', 'Traceability', 'Security', 'Program Generation & Repair',
        'Performance Prediction', 'Code Similarity & Clone Detection',
        'Code Navigation & Understanding', 'Other_SE']
count_ANN = [2.0, 0.0, 1.0, 0.0, 0.0, 3.0, 5.0, 1.0]
count_CNN = [1.0, 0.0, 5.0, 0.0, 1.0, 4.0, 4.0, 0.0]
count_RNN = [1.0, 0.0, 3.0, 1.0, 0.0, 4.0, 7.0, 2.0]
count_LSTM = [3.0, 0.0, 5.0, 3.0, 1.0, 9.0, 15.0, 1.0]
count_GNN = [0.0, 0.0, 1.0, 0.0, 0.0, 3.0, 3.0, 3.0]
count_AE = [0.0, 0.0, 1.0, 3.0, 0.0, 6.0, 11.0, 0.0]
count_AM = [2.0, 0.0, 1.0, 4.0, 1.0, 4.0, 15.0, 1.0]
count_other = [1.0, 0.0, 2.0, 2.0, 0.0, 1.0, 3.0, 0.0]

all_counts = [count_ANN, count_CNN, count_RNN, count_LSTM, count_GNN, count_AE, count_AM, count_other]
all_labels = ["ANN", "CNN", "RNN", "LSTM", "GNN", "Auto-Encoder", "AM", "Others"]

cols = ["\n".join(c.split(" ")) for c in cols]
cols = [c.replace("&\n", "& ") for c in cols]

bottom = 0
for count_i, label in zip(all_counts, all_labels):
    plt.bar(cols, count_i, 0.4, bottom=bottom, label=label)
    bottom += np.array(count_i)

# plt.xticks(np.arange(len(cols)) + 0.1, cols)
plt.tick_params(axis='x', labelrotation=45, length=0)
plt.legend()
plt.tight_layout()
plt.show()

结果图

PS:要使条形图与图例的顺序相同,您可以从顶部开始绘制它们:

bottom = np.sum(all_counts, axis=0)
for count_i, label in zip(all_counts, all_labels):
    bottom -= np.array(count_i)
    plt.bar(cols, count_i, 0.4, bottom=bottom, label=label)

推荐阅读