Inconsistent plotting behaviour of a radial bar chart in matplotlib

Problem description

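I'm trying to create a circular bar chart in Python/matplotlib. My code works fine on my desktop (using Spyder), but on other computers the output is completely different, even with the same versions of Python (3.5) and Spyder (3.3.1). The circular bar chart collapses into a single segment, and I don't understand why.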
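Identical Python and Spyder versions do not imply identical library versions: matplotlib, pandas and NumPy are installed separately and can differ between machines. A minimal sketch for listing them (run it on each computer and compare the output; the helper name `library_versions` is illustrative):

```python
import importlib


def library_versions(names):
    """Return {module name: version string} for each name.

    Modules that cannot be imported are reported as "not installed";
    modules without a __version__ attribute are reported as "unknown".
    """
    versions = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = "not installed"
            continue
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    import sys
    print("Python", sys.version.split()[0])
    for name, version in library_versions(["matplotlib", "pandas", "numpy"]).items():
        print(name, version)
```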

What it looks like on my desktop (what it is supposed to look like):

What it looks like on other computers:

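The code reads the data from a spreadsheet. It essentially extracts the countries and their percentage values, and uses each country's region to choose the bar colour.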
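The if/elif chain that assigns colours in the code below can also be written as a dictionary lookup. A sketch using the hex codes from the question, with 'black' as the fallback for unrecognised regions as in the original; the names `REGION_COLOURS` and `region_colours` are illustrative:

```python
# Hex codes copied from the if/elif chain in the question's code.
REGION_COLOURS = {
    "Africa": "#490E3D",
    "APAC": "#E97F02",
    "Europe": "#F8CA00",
    "LATAM": "#88C100",
    "Middle East": "#30C4C9",
    "North America": "#E40D2C",
}


def region_colours(regions, default="black"):
    """Map each region name to its bar colour; unknown regions get `default`."""
    return [REGION_COLOURS.get(str(region), default) for region in regions]


print(region_colours(["Europe", "APAC", "Antarctica"]))
# ['#F8CA00', '#E97F02', 'black']
```

With a pandas Series of region names, `Regions.map(REGION_COLOURS).fillna('black').tolist()` produces the same list.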

Sample input: data.xlsx
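Since data.xlsx is not included here, the following sketch reproduces only the cleaning steps from the code (fill blank region cells downwards, drop incomplete rows, drop rows with a count of 2 or less) on a small in-memory DataFrame. The rows are made-up placeholders, not the question's data, and `Series.ffill` stands in for the row-by-row `get_value` loop (`DataFrame.get_value` was removed in pandas 1.0):

```python
import numpy as np
import pandas as pd


def clean(rawdata):
    """Apply the question's cleaning steps to a frame with the columns
    regions, countries, count, value."""
    data = rawdata.copy()
    # A blank region cell means "same region as the row above".
    data["regions"] = data["regions"].ffill()
    # Drop rows that still have any missing value.
    data = data.dropna(how="any")
    # Keep only rows with a count of 3 or more.
    data = data[data["count"] > 2]
    return data.reset_index(drop=True)


# Placeholder rows for illustration only.
raw = pd.DataFrame({
    "regions": ["Europe", np.nan, np.nan, "APAC"],
    "countries": ["France", "Germany", "Spain", "Japan"],
    "count": [5, 2, 4, 3],
    "value": [0.05, 0.01, np.nan, -0.03],
})
print(clean(raw))
```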

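I have made sure the Python versions are the same and have gone through the code line by line. I think the problem is related to the call that creates the thetagrid (about 10 lines from the bottom), which behaves differently from computer to computer. That is where the chart changes from a circle into its strange segmented shape.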
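The angles passed to `thetagrids` can be inspected without plotting anything. For N bars, `np.linspace(0.0, 2*np.pi, N+1) - np.pi/N` yields N + 1 angles running from -180/N degrees up to 360 - 180/N degrees: the first one is negative, and the last one points in the same direction as the first. A small NumPy check of that arithmetic (the function name is illustrative):

```python
import numpy as np


def question_thetagrid_deg(n):
    """Gridline angles exactly as computed in the question, in degrees."""
    return np.degrees(np.linspace(0.0, 2 * np.pi, n + 1) - np.pi / n)


grid = question_thetagrid_deg(4)
print(grid)            # approximately [-45, 45, 135, 225, 315]
print(grid.min() < 0)  # True: the first angle is below 0 degrees
# True: the first and last angles point in the same direction
print(np.isclose(grid[-1] % 360, grid[0] % 360))
```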

import pandas as pd
import pylab as pl
import numpy as np
import matplotlib.pyplot as plt

#plot params
bottom = 0
max_height = 10
DPI = 400
figsize = DPI*8
MaxPercentage = 0.12
MinPercentage = -0.12
Padlength = 45 #This is the maximum country name length

for sheetcount in range(0,2):

    #read in data
    xl = pd.ExcelFile('data.xlsx')
    sheet_names = xl.sheet_names
    currentsheet = sheet_names[sheetcount]
    # sheet_name replaced the older sheetname keyword in pandas 0.21
    datafortitle = pd.read_excel('data.xlsx', sheet_name=currentsheet)
    ChartTitle = datafortitle.iloc[0,1]
    rawdata = pd.read_excel('data.xlsx', sheet_name=currentsheet, usecols=[0,1,2,3], skiprows=3)
    rawdata.columns = ['regions','countries','count','value']

    #populate regions (remove gaps): fill blank region cells with the region above
    # (replaces a get_value loop; DataFrame.get_value was removed in pandas 1.0)
    rawdata['regions'] = rawdata['regions'].ffill()


    # drop rows that have any missing value
    rawdata = rawdata[~pd.isnull(rawdata).any(axis=1)]

    # Get the index labels of rows whose count is 2 or less
    indexNames = rawdata[rawdata['count'] <= 2].index

    # Delete these row indexes
    rawdata.drop(indexNames , inplace=True)
    rawdata = rawdata.reset_index(drop = True)

    fulldata = rawdata

    #cut into respective groups
    # (np.str / np.float were removed in NumPy 1.24; the built-in types are equivalent)
    Regions = fulldata.iloc[:,0].astype(str)
    Countries_Lower = fulldata.iloc[:,1].astype(str)
    Countries = Countries_Lower.str.upper()
    ActualSalary = fulldata.iloc[:,3].astype(float)
    ActualSalaryStr = fulldata.iloc[:,3].astype(str)

    N = len(ActualSalary)

    #Set colours for regions
    Colours = ["" for x in range(N)]
    i = -1
    for x in Regions:
        i += 1
        x = str(x)
        if x == 'Africa':
            Colours[i] = '#490E3D'
        elif x == 'APAC':
            Colours[i] = '#E97F02'
        elif x == 'Europe':
            Colours[i] = '#F8CA00'        
        elif x == 'LATAM':
            Colours[i] = '#88C100'
        elif x == 'Middle East':
            Colours[i] = '#30C4C9'
        elif x == 'North America':
            Colours[i] = '#E40D2C'
        else:
            Colours[i] = 'black'        


    TrimmedSalary = ActualSalary.copy()
    TrimLocation = np.zeros(len(ActualSalary))

    #Trim values to be plotted to be less than or equal to MaxPercentage
    for i in range(0,len(ActualSalary)):
        if TrimmedSalary[i] > MaxPercentage:
            TrimmedSalary[i] = MaxPercentage
            TrimLocation [i] = 1
        if TrimmedSalary[i] < MinPercentage:
            TrimmedSalary[i] = MinPercentage
            TrimLocation [i] = 2


    theta = np.linspace(0.0, 2 * np.pi, N, endpoint=False) - np.pi/N
    r = max_height*TrimmedSalary
    width = (2*np.pi) / N

    with plt.rc_context({'font.cursive':'Textile','axes.edgecolor':'#8a8a8a', 'xtick.color':'#8a8a8a', 'figure.facecolor':'white'}):# 'font.family':'cursive','font.sans-serif':'Geneva'}):

        f = plt.figure(figsize=(figsize/DPI,figsize/DPI))                 
        ax = f.add_subplot(111, polar = True)

        #remove y axis clutter, setup x axis.
        ax.xaxis.grid(color = '#8a8a8a', linestyle = '-')
        ax.yaxis.grid(False)
        ax.set_yticklabels([])
        ax.set_axisbelow(True)
        ax.FontName = 'Sans'

        #add country labels
        ax.set_xticks(np.linspace(0,2*np.pi,N+1))

        bars = ax.bar(theta, r, width=width, bottom=bottom, edgecolor = "none", linewidth = 0)

        # Use custom colors and opacity
        for colour, bar in zip(Colours, bars):
            bar.set_facecolor(colour)
            bar.set_alpha(1.0)

        Colours.append('black')

        #Fix plot size
        ax.set_ylim([MinPercentage*max_height,MaxPercentage*max_height])

        ############################################
        #Plot labels
        Valuelabels = ["%.2f" % x for x in (ActualSalary*100)] 
        Valuelabels = [s + " %" for s in Valuelabels]
        maxlen = len(max(Valuelabels, key=len))

        angles = np.linspace(0,2*np.pi - 2*np.pi/N,N)
        angles[np.cos(angles) < 0] = angles[np.cos(angles) < 0] + np.pi
        angles = np.rad2deg(angles)
        angleplace = np.linspace(0,2*np.pi,N+1)
        angleplace = np.rad2deg(angleplace)

        for i in range(0,N):
            if 90 < angleplace[i] <= 270:
                Valuelabels[i] = Valuelabels[i].rjust(maxlen+1)      
            else:
                Valuelabels[i] = Valuelabels[i].ljust(maxlen+1)


        CombinedLabels = ["" for x in range(N)]

        for i in range(0,N):
            if 0 <= angleplace[i] <= 90:
                CombinedLabels[i] = Valuelabels[i] + Countries[i]
            elif 90 < angleplace[i] <= 270:
                CombinedLabels[i] = Countries[i] + Valuelabels[i]          
            else:
                CombinedLabels[i] = Valuelabels[i] + Countries[i]              


        #Pad to the same length
        maxlen = Padlength        
        for i in range(0,N):
            if 90 < angleplace[i] <= 270:
                CombinedLabels[i] = CombinedLabels[i].rjust(maxlen)      
            else:
                CombinedLabels[i] = CombinedLabels[i].ljust(maxlen)    

        ax.set_xticklabels(CombinedLabels, fontsize=10,)
        plt.gcf().canvas.draw()
        labels = []

        radius = 4    
        i = 0
        for label, angle in zip(ax.get_xticklabels(), angles):   
            x = i/N*(2*np.pi)
            #y = 3.1
            y = 2.77

            # fontproperties=prop removed: prop is not defined anywhere in this snippet
            lab = ax.text(x, y, label.get_text(), color=Colours[i],
                          ha='center', va='center',
                          weight='bold', fontsize=10)
            lab.set_rotation(angle)
            labels.append(lab)
            i += 1


        #Plot gridlines - SOMETHING GOES WRONG HERE
        thetagrid = np.linspace(0.0, 2*np.pi, N+1) - np.pi/N
        (lines,labelsempty)=plt.thetagrids(np.degrees(thetagrid))

        ax.set_xticklabels([])
        ax.set_axisbelow(False)

        #Add central circle
        circle = pl.Circle((0, 0), 0.4, transform=ax.transData._b, fc="white",ec = '#8a8a8a', zorder = 10)
        ax.add_artist(circle)        

        plt.savefig(ChartTitle, dpi = DPI, bbox_inches='tight')
        #plt.show()

Sorry for all the code, but I don't know how to cut it down.
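To cut the example down, the spreadsheet, colours and labels can be dropped, keeping only the polar bars and the gridline call. The sketch below does that with placeholder bar heights and reports the theta limits before and after `ax.set_thetagrids`, the method that `plt.thetagrids` calls. Current matplotlib documents that setting tick locations expands the axis view limits so that every tick is visible, so the negative first angle should pull the lower theta limit below 0 while the upper limit stays at 2π, giving a range wider than a full circle. Polar axes have been able to draw partial circles since matplotlib 2.1, so comparing this output on both computers is one way to test whether this is what differs; it is a hypothesis, not a confirmed diagnosis. The function name is illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen, no window needed
import matplotlib.pyplot as plt
import numpy as np


def theta_limits_around_thetagrids(n):
    """Return the theta limits (radians) before and after the question's
    gridline call, on a polar bar chart with n placeholder bars."""
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111, polar=True)
    theta = np.linspace(0.0, 2 * np.pi, n, endpoint=False) - np.pi / n
    ax.bar(theta, np.ones(n), width=2 * np.pi / n)  # placeholder heights
    ax.set_xticks(np.linspace(0, 2 * np.pi, n + 1))
    before = ax.get_xlim()

    # Same angles as the question's thetagrid line.
    thetagrid = np.linspace(0.0, 2 * np.pi, n + 1) - np.pi / n
    ax.set_thetagrids(np.degrees(thetagrid))
    after = ax.get_xlim()
    plt.close(fig)
    return before, after


if __name__ == "__main__":
    print("matplotlib", matplotlib.__version__)
    before, after = theta_limits_around_thetagrids(12)
    print("before (deg):", np.degrees(before))
    print("after (deg): ", np.degrees(after))
```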

Does anyone know what I might be missing? A version difference, perhaps? A change to thetagrids? Or am I simply using all of this the wrong way?
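If widened theta limits do turn out to be the cause, two adjustments look worth trying: pass `thetagrids` the N gridline directions wrapped into [0°, 360°) (dropping the duplicate last angle), so that no tick falls outside the default limits, and reset the limits to a full circle afterwards with `ax.set_xlim(0, 2*np.pi)`. A sketch of both on placeholder bars; it has not been checked against the question's spreadsheet, and the function name is illustrative. Separately, `bar` has centred bars on their x values by default since matplotlib 2.0 (earlier versions aligned them by their edge), which is another version-dependent difference in where bars and gridlines line up.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen, no window needed
import matplotlib.pyplot as plt
import numpy as np


def polar_bars_full_circle(n):
    """Draw n placeholder polar bars with gridlines and return the theta
    limits (radians) after the gridlines have been set."""
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111, polar=True)
    theta = np.linspace(0.0, 2 * np.pi, n, endpoint=False) - np.pi / n
    ax.bar(theta, np.ones(n), width=2 * np.pi / n)  # placeholder heights

    # Same directions as the question's thetagrid, but n angles wrapped
    # into [0, 2*pi) instead of n + 1 angles starting below 0.
    ax.set_thetagrids(np.degrees(np.mod(theta, 2 * np.pi)))
    # Force a full circle again in case anything widened the limits.
    ax.set_xlim(0.0, 2 * np.pi)

    limits = ax.get_xlim()
    plt.close(fig)
    return limits


print(np.degrees(polar_bars_full_circle(12)))  # approximately [0, 360]
```

In the question's code this corresponds to building the gridline angles from `theta` in the same way, or to adding `ax.set_xlim(0, 2*np.pi)` right after the `plt.thetagrids(...)` line.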

I would really appreciate any help you can offer; I've been trying to solve this for weeks and have hit a dead end.

Many thanks.

Tags: python, matplotlib, plot
