首页 > 解决方案 > 如何将图例自定义为直线而不是矩形

问题描述

这是我的代码片段:

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List,Optional

def draw(
    df: pd.DataFrame,
    colors: List[str],
    ax_labels: List[str],
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    show_y_value=True,
    dpi=70,
):
    """Show a bar chart of ``qgmv`` with two marker lines on twin y-axes.

    One bar is drawn per ``cate_level2_name``. ``qgmv_yoy`` and ``mat_yoy``
    are then plotted as lines with ``'o'`` markers, each after creating a
    twin axes whose y-range is fixed to ``[0, 1]``. A three-entry legend is
    built from thin ``Rectangle`` proxy patches, and the figure is shown.

    NOTE(review): ``sns`` and ``cur_font`` are used below but are neither
    imported nor assigned in the visible snippet; they must be provided by
    the surrounding module.

    Args:
        df: Frame providing the columns ``cate_level2_name``, ``qgmv``,
            ``qgmv_yoy`` and ``mat_yoy``.
        colors: At least three colors, in order: bars, ``qgmv_yoy`` line,
            ``mat_yoy`` line.
        ax_labels: Legend labels, matched in order to the three handles.
        x_label: X-axis label; ``None`` yields an empty label.
        y_label: Label of the bar y-axis; ``None`` yields an empty label.
        show_y_value: Visibility of the y-axis on the bar axes and on the
            first twin axes. The second twin's y-axis is always hidden.
        dpi: Resolution passed to ``plt.figure``.
    """
    style_overrides = {"axes.grid": False, "font.sans-serif": [cur_font]}
    sns.set_style("whitegrid", style_overrides)

    plt.figure(figsize=(16, 8), dpi=dpi)
    bar_color, first_line_color, second_line_color = colors[0], colors[1], colors[2]
    ax = sns.barplot(x=df.cate_level2_name, y=df.qgmv, data=df, color=bar_color)

    def overlay_line(values, line_color):
        # Keep this order: the twin axes is created first, then pyplot's
        # state-based ``plt.plot`` is used (presumably drawing on that
        # newly created twin as the current axes).
        twin = ax.twinx()
        twin.set_ylim(0, 1)
        plt.plot(ax.get_xticks(), values, color=line_color, marker='o')
        return twin

    # gmv_yoy axis
    ax2 = overlay_line(df.qgmv_yoy, first_line_color)

    # mat_yoy axis
    ax3 = overlay_line(df.mat_yoy, second_line_color)

    # Axis labels and tick label sizes.
    ax.set_xlabel(x_label if x_label is not None else '', fontsize=30)
    ax.set_ylabel(y_label if y_label is not None else '', fontsize=16)
    for axis in (ax.xaxis, ax.yaxis, ax2.yaxis):
        axis.set_tick_params(labelsize=16)

    # Y-axis visibility: the two leftmost-created axes follow the flag,
    # the last twin never shows its y-axis.
    for axes_obj, visible in ((ax, show_y_value), (ax2, show_y_value), (ax3, False)):
        axes_obj.yaxis.set_visible(visible)

    # Legend handles: one very flat rectangle per color.
    patches = [
        mpatches.Rectangle((0, 0), 1, 0.01, color=swatch_color)
        for swatch_color in colors[:3]
    ]
    plt.legend(patches, ax_labels, loc='upper right', fontsize=16)

    plt.show()

# Small example frame. ``draw`` reads ``cate_level2_name``, ``qgmv``,
# ``qgmv_yoy`` and ``mat_yoy``; ``cate_level2_id`` is not used by it.
df = pd.DataFrame(
    data={
        'qgmv': [1231251324, 45423423423, 34324234, 2342342355],
        'qgmv_yoy': [0.3, 0.6, 0.2, 0.7],
        'mat_yoy': [0.8, 0.6, 0.5, 0.1],
        'cate_level2_id': [1, 2, 3, 4],
        'cate_level2_name': ['blah1', 'blah2', 'blah3', 'blah4'],
    },
)

# Colors are given as: bars, first line, second line; the labels are the
# three legend entries in the same order.
draw(
    df,
    ['#DDDEE0', 'red', 'blue'],
    ['1', '2', '3'],
    show_y_value=False,
)

结果如下所示:(此处原为结果截图,图片未随页面保留)

如何将图例中第 2 项和第 3 项的标记从矩形改为线条?(我尝试把矩形的高度设置为 0.01,但似乎不起作用)

标签: python matplotlib

解决方案


我刚刚找到了解决方案:在构建图例句柄时,用 Line2D 对象代替 mpatches.Rectangle:

# Line2D serves as the legend handle so the entry is drawn as a line.
from matplotlib.lines import Line2D
# Fragment: meant to replace the `patches.append(mpatches.Rectangle(...))`
# statement inside the `for i in range(3)` legend loop of `draw`; `patches`,
# `colors` and `i` come from that loop. As written it applies to every `i`,
# so the bar entry would become a line too.
patches.append(Line2D([0], [0], color=colors[i], linewidth=3, linestyle='-'))

推荐阅读