首页 > 解决方案 > ax.get_yticks() 在 matplotlib 中对轴进行转换

问题描述

[编辑] 我正在尝试将热图与箱线图结合起来。思路是把箱体按离散区间划分,并根据各区间的数据相应地着色。关于数据的补充说明:这是按周汇总的、同一分钟内同时使用的令牌数量(因此同一周会对应多行记录)。样本数据如下:

|    | start_hour          | week       |   TokenUsed |
|---:|:--------------------|:-----------|------------:|
|  0 | 2019-12-19 20:20:00 | 2019-12-22 |           8 |
|  1 | 2019-12-19 20:21:00 | 2019-12-22 |           8 |
|  2 | 2019-12-19 20:22:00 | 2019-12-22 |           8 |
|  3 | 2019-12-19 20:23:00 | 2019-12-22 |           8 |
|  4 | 2019-12-19 20:24:00 | 2019-12-22 |           8 |
...

| 43370 | 2020-03-11 11:40:00 | 2020-03-15 |           5 |
| 43371 | 2020-03-11 11:41:00 | 2020-03-15 |           5 |
| 43372 | 2020-03-11 11:42:00 | 2020-03-15 |           5 |
| 43373 | 2020-03-11 11:43:00 | 2020-03-15 |           5 |
| 43374 | 2020-03-11 11:44:00 | 2020-03-15 |           5 |

然后我使用以下内容生成箱线图:

# Purpose: compute per-week box statistics from `minute_by_minute`, draw them
# with ax.bxp, then overlay one coloured Rectangle per distinct TokenUsed value
# lying between q1 and q3 on every box.
# NOTE(review): `minute_by_minute`, `feature`, `maxs` and `find_spot` are not
# defined in this snippet; they must already exist in the surrounding session.
df= minute_by_minute.set_index("week")
percentile = 95

# y collects one stats dict per week (the input for ax.bxp); x collects the
# week labels and is not used again in this snippet.
x,y = [],[]
fig, ax = plt.subplots(figsize = (12,6))
for date in  df.index.unique():
    # The box edges are requested at the (100-percentile)th and percentile-th
    # percentiles instead of the quartiles; whis="min/max" asks for whiskers
    # at the data extremes.
    y.append(my_boxplot_stats(df.loc[date]["TokenUsed"], percents=[100-percentile,percentile], labels=[date], whis="min/max")[0])
    x.append(date)


# One row per week, one column per statistic ('q1', 'q3', 'label', ...).
data_box = pd.DataFrame(y)

# Fraction of all rows whose TokenUsed is strictly below the largest q3.
coverage = (minute_by_minute["TokenUsed"] < data_box['q3'].max()).sum()/(len(df["TokenUsed"]))
ax.set_title(f"Feature:{feature} : The max {data_box['q3'].max()} at {data_box.set_index('label')['q3'].idxmax()}\n It covers {coverage:.2%} ")
bplot = ax.bxp(y,patch_artist=True)

# Keeps a reference to every rectangle added below; not read again here.
rectangles = []
ax.yaxis.set_minor_locator(AutoMinorLocator(2))


# Number of rows for every (week, TokenUsed) pair.
color_map = minute_by_minute.groupby(by=["week","TokenUsed"])["TokenUsed"].count()

for patch,q1,q3,label in zip(bplot['boxes'], data_box["q1"],data_box["q3"],data_box["label"]):
    # print(f"Label {label}")

    # The box vertices are passed through the inverse of ax.transAxes.
    # NOTE(review): the Rectangles below are created without an explicit
    # transform, so they are presumably interpreted in data coordinates --
    # confirm which coordinate system these vertices are meant to be in.
    verts = patch.get_verts().copy()
    verts = ax.transAxes.inverted().transform(verts) 
    print(f"Verts : {verts}")
    # Sum of the difference vector between vertices 3 and 0; this equals the
    # box height only if both share the same x (assumed for an upright box).
    height = (verts[3] - verts[0]).sum()
    # `width` is computed but never used below (0.5 is hard-coded instead).
    width = (verts[1] - verts[0]).sum()

    # Number of distinct TokenUsed values of this week inside [q1, q3]; the
    # box height is divided equally between them.
    items = len(color_map.loc[label].loc[q1:q3])
    height_per_unit = height/items    
    # Starting corner for the first rectangle: vertex 0 of the box.
    initial_verts = verts[0].copy()
    
    
    
   
    # Share of each TokenUsed value within this week (count / weekly total);
    # used as the position in the colour map.
    color_value = (color_map.loc[label]/color_map.loc[label].sum())

    # `count` is incremented but never read.
    count = 0
    for value, minutes  in color_map.loc[label].loc[q1:q3].to_frame().iterrows():
        count += 1
        rect = patches.Rectangle((initial_verts[0],initial_verts[1]),
        0.5, 
        height_per_unit,
        linewidth=0,
        facecolor= sns.light_palette("red",as_cmap=True)(color_value.loc[value]),
        zorder=3000)
        rectangles.append(rect)
        

        ax.add_patch(rect)

        # The next rectangle starts at vertex 3 of the one just added
        # (presumably its top-left corner -- confirm), again after applying
        # the inverse of ax.transAxes.
        initial_verts = ax.transAxes.inverted().transform(rect.get_verts())
        initial_verts = initial_verts[3]
        

ax.axhline(maxs[feature])
ax.axhline(find_spot, ls="--", c="k",alpha = 0.5, dash_capstyle="round")
plt.xticks(rotation=30)

plt.savefig(f"{feature}_{100-percentile}_{percentile}.png")
plt.show()

这给出了以下结果:

箱线图 + 热图

但是,绘制在箱线图上的框的高度由 (height_box_plot/number_of_intervals) 定义。

|   TokenUsed |   TokenUsed |
|------------:|------------:|
|           5 |         685 |
|           8 |          20 |
|          10 |        1835 |
|          15 |         335 |
|          16 |         595 |
|          21 |          65 |
|          23 |         130 |
|          24 |         270 |
|          26 |           5 |
|          29 |          40 |
|          31 |         130 |
|          32 |         210 |

第一周绘制的方框

我需要的是让这些小矩形与刻度线对齐(即按数值分组)。为此,我尝试在循环之前调用 ax.get_yticks(),但这显然“变换”了坐标轴,并得到了这张奇怪的图。我所说的“变换”是什么意思呢?

如果分别在不调用和调用 ax.get_yticks() 的情况下获取补丁(bplot['boxes'])的顶点,会得到两组不同的值:

# Without ax.get_yticks()
    Verts : [[ 0.75  5.  ]
     [ 1.25  5.  ]
     [ 1.25 32.  ]
     [ 0.75 32.  ]
     [ 0.75  5.  ]]

# With ax.get_yticks()
    Verts : [[0.01923077 0.07997699]
     [0.05769231 0.07997699]
     [0.05769231 0.39067894]
     [0.01923077 0.39067894]
     [0.01923077 0.07997699]]

下面是一个最小可复现示例。您可以通过 toggle_get_yticks 开关启用或禁用 ax.get_yticks() 调用,以观察箱线图顶点的变化。


#%%
import time
from functools import partial
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.cbook import _reshape_2D
from matplotlib.collections import PatchCollection
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

def my_boxplot_stats(X, whis=1.5, bootstrap=None, labels=None,
                     autorange=False, percents=(25, 75)):
    """Compute box-plot statistics with configurable box percentiles.

    Each data set in ``X`` yields one dict that can be handed to
    ``Axes.bxp``. Unlike the usual quartile-based box, the lower and upper
    box edges are taken at ``percents[0]`` and ``percents[1]``.

    Parameters
    ----------
    X : array-like
        Input data; converted to a list of 1-D data sets with
        ``_reshape_2D``.
    whis : float, str or pair of floats, default 1.5
        Whisker reach. A real number extends the whiskers ``whis * iqr``
        beyond the box edges; one of ``'range'``, ``'limit'``, ``'limits'``
        or ``'min/max'`` uses the data extremes; a pair is interpreted as
        lower/upper percentiles.
    bootstrap : int, optional
        Number of bootstrap resamples for the confidence interval around
        the median. When ``None`` a closed-form estimate is used.
    labels : sequence, optional
        One label per data set. When omitted, the dicts carry no
        ``'label'`` key.
    autorange : bool, default False
        When true and the box has zero height, the whiskers fall back to
        the data extremes.
    percents : pair of floats, default (25, 75)
        Percentiles used for the lower and upper box edges.

    Returns
    -------
    list of dict
        One dict per data set with the keys ``mean``, ``iqr``, ``cilo``,
        ``cihi``, ``whislo``, ``whishi``, ``fliers``, ``q1``, ``med``,
        ``q3`` and, if labels were given, ``label``.

    Raises
    ------
    ValueError
        If ``labels`` does not match the number of data sets, or ``whis``
        is a scalar that is neither a real number nor a recognised string.
    """
    string_whis = ('range', 'limit', 'limits', 'min/max')
    whis_error = ('whis must be a float, valid string, or list '
                  'of percentiles')

    def _bootstrap_median(data, N=5000):
        # determine 95% confidence intervals of the median
        M = len(data)
        percentiles = [2.5, 97.5]

        bs_index = np.random.randint(M, size=(N, M))
        bsData = data[bs_index]
        estimate = np.median(bsData, axis=1, overwrite_input=True)

        return np.percentile(estimate, percentiles)

    def _compute_conf_interval(data, med, iqr, bootstrap):
        if bootstrap is not None:
            # Do a bootstrap estimate of notch locations:
            # confidence interval around the median.
            CI = _bootstrap_median(data, N=bootstrap)
            return CI[0], CI[1]

        N = len(data)
        half_width = 1.57 * iqr / np.sqrt(N)
        return med - half_width, med + half_width

    # output is a list of dicts
    bxpstats = []

    # convert X to a list of lists
    X = _reshape_2D(X, "X")

    ncols = len(X)
    if labels is None:
        # One placeholder per data set; avoids depending on itertools,
        # which this file never imports.
        labels = [None] * ncols
    elif len(labels) != ncols:
        raise ValueError("Dimensions of labels and X must be compatible")

    input_whis = whis
    for x, label in zip(X, labels):
        stats = {}
        if label is not None:
            stats['label'] = label

        # restore whis to the input value in case autorange changed it
        # for a previous data set
        whis = input_whis

        # appended first and filled in below, so that the empty-data
        # branch can simply `continue`
        bxpstats.append(stats)

        # if empty, bail
        if len(x) == 0:
            stats['fliers'] = np.array([])
            for key in ('mean', 'med', 'q1', 'q3', 'iqr',
                        'cilo', 'cihi', 'whislo', 'whishi'):
                stats[key] = np.nan
            continue

        # up-convert to an array, just to be safe
        x = np.asarray(x)

        # arithmetic mean
        stats['mean'] = np.mean(x)

        # median and the (customisable) box edges
        med = np.percentile(x, 50)
        q1, q3 = np.percentile(x, (percents[0], percents[1]))

        # height of the box ("interquartile" range for the default percents)
        stats['iqr'] = q3 - q1
        if stats['iqr'] == 0 and autorange:
            whis = 'range'

        # conf. interval around median
        stats['cilo'], stats['cihi'] = _compute_conf_interval(
            x, med, stats['iqr'], bootstrap
        )

        # lowest/highest non-outliers
        if isinstance(whis, str):
            # Strings are tested explicitly instead of relying on how
            # np.isreal treats a str.
            if whis not in string_whis:
                raise ValueError(whis_error)
            loval = np.min(x)
            hival = np.max(x)
        elif np.isscalar(whis):
            if not np.isreal(whis):
                raise ValueError(whis_error)
            loval = q1 - whis * stats['iqr']
            hival = q3 + whis * stats['iqr']
        else:
            loval = np.percentile(x, whis[0])
            hival = np.percentile(x, whis[1])

        # get high extreme
        wiskhi = np.compress(x <= hival, x)
        if len(wiskhi) == 0 or np.max(wiskhi) < q3:
            stats['whishi'] = q3
        else:
            stats['whishi'] = np.max(wiskhi)

        # get low extreme
        wisklo = np.compress(x >= loval, x)
        if len(wisklo) == 0 or np.min(wisklo) > q1:
            stats['whislo'] = q1
        else:
            stats['whislo'] = np.min(wisklo)

        # compute a single array of outliers
        stats['fliers'] = np.hstack([
            np.compress(x < stats['whislo'], x),
            np.compress(x > stats['whishi'], x)
        ])

        # add in the remaining stats
        stats['q1'], stats['med'], stats['q3'] = q1, med, q3

    return bxpstats


#### INPUT DATA #####
# Minimal reproducible example: synthetic per-minute data for January 2020,
# grouped by week, drawn as a box plot whose boxes are then overlaid with
# coloured rectangles. `toggle_get_yticks` switches the ax.get_yticks() call
# whose effect on the printed vertices is being investigated.
np.random.seed(10)
# One row per minute between 2020-01-01 and 2020-02-01.
data = pd.DataFrame(pd.date_range("2020-01-01","2020-02-01", freq="1min"),columns=["start_hour"])
# Normally distributed values (mean 15, std 2), truncated to integers.
data["TokenUsed"] = np.random.normal(15,2, data.shape[0])
data["TokenUsed"] = data["TokenUsed"].astype(int)
# Group by (weekly bucket, minute); the first index level is renamed to
# "week" and both levels are turned back into ordinary columns.
minute_by_minute = data.groupby([pd.Grouper(key="start_hour",freq="1W"),"start_hour"]).sum()
minute_by_minute.index.rename(level=0,names="week",inplace=True)
minute_by_minute.reset_index(inplace=True)
percentile = 95
# Set to True to call ax.get_yticks() before the colouring loop.
toggle_get_yticks = False


############ BOX PLOT GENERATION
df = minute_by_minute.set_index("week")
# y collects one stats dict per week (the input for ax.bxp); x collects the
# week labels and is not used again in this snippet.
x,y = [],[]
fig, ax = plt.subplots(figsize = (12,6))
for date in  df.index.unique():

    # Box edges at the (100-percentile)th and percentile-th percentiles;
    # whis="min/max" asks for whiskers at the data extremes.
    y.append(my_boxplot_stats(df.loc[date]["TokenUsed"], percents=[100-percentile,percentile], labels=[date], whis="min/max")[0])
    x.append(date)

# One row per week, one column per statistic ('q1', 'q3', 'label', ...).
data_box = pd.DataFrame(y)

# Fraction of all rows whose TokenUsed is strictly below the largest q3.
coverage = (df["TokenUsed"] < data_box['q3'].max()).sum()/(len(df["TokenUsed"]))
ax.set_title(f" The max {data_box['q3'].max()} at {data_box.set_index('label')['q3'].idxmax()}\n It covers {coverage:.2%} ")
bplot = ax.bxp(y,patch_artist=True)
# Keeps a reference to every rectangle added below; not read again here.
rectangles = []
ax.yaxis.set_minor_locator(AutoMinorLocator(2))
if toggle_get_yticks:
    # The return value is discarded; the call is made only to observe its
    # effect on the vertices printed in the loop below.
    ax.get_yticks()

# Number of rows for every (week, TokenUsed) pair.
color_map = minute_by_minute.groupby(by=["week","TokenUsed"])["TokenUsed"].count()
######## COLORING THE BOXPLOT
for patch,q1,q3,label in zip(bplot['boxes'], data_box["q1"],data_box["q3"],data_box["label"]):
    # The box vertices are passed through the inverse of ax.transAxes.
    # NOTE(review): the Rectangles below are created without an explicit
    # transform, so they are presumably interpreted in data coordinates --
    # confirm which coordinate system these vertices are meant to be in.
    verts = patch.get_verts().copy()
    verts = ax.transAxes.inverted().transform(verts) 
    print(f"Verts : {verts}")
    # Sum of the difference vector between vertices 3 and 0; this equals the
    # box height only if both share the same x (assumed for an upright box).
    height = (verts[3] - verts[0]).sum()
    # `width` is computed but never used below (0.5 is hard-coded instead).
    width = (verts[1] - verts[0]).sum()

    # Number of distinct TokenUsed values of this week inside [q1, q3]; the
    # box height is divided equally between them.
    items = len(color_map.loc[label].loc[q1:q3])
    height_per_unit = height/items
    
    # Starting corner for the first rectangle: vertex 0 of the box.
    initial_verts = verts[0].copy()
    
    # Share of each TokenUsed value within this week (count / weekly total);
    # used as the position in the colour map.
    color_value = (color_map.loc[label]/color_map.loc[label].sum())

    # `count` is incremented but never read.
    count = 0
    for value, minutes  in color_map.loc[label].loc[q1:q3].to_frame().iterrows():
        count += 1
        rect = patches.Rectangle((initial_verts[0],initial_verts[1]),
        0.5, 
        height_per_unit,
        linewidth=0,
        facecolor= sns.light_palette("red",as_cmap=True)(color_value.loc[value]),
        zorder=3000)
        rectangles.append(rect)
        
        ax.add_patch(rect)
        
        # The next rectangle starts at vertex 3 of the one just added
        # (presumably its top-left corner -- confirm), again after applying
        # the inverse of ax.transAxes.
        initial_verts = ax.transAxes.inverted().transform(rect.get_verts())
        initial_verts = initial_verts[3]
       
plt.xticks(rotation=30)

plt.show()

# %%

有没有办法阻止或扭转这种转变?

标签: pythonpandasmatplotlib

解决方案


让我们看看我是否理解正确。

我统计了每周各令牌数值出现的次数,并按可自定义的 binSize 进行分箱。我不太确定您希望以何种方式分箱。绘制热图时,我为每个箱体创建一幅图像,并用箱体的轮廓对该图像进行裁剪。

#### INPUT DATA #####
# Draws one box per week and fills each box with a heat map of binned
# TokenUsed counts, clipped to the outline of the box.
# Relies on my_boxplot_stats and on the imports of the question's example
# (np, pd, plt, sns, AutoMinorLocator).
np.random.seed(10)
data = pd.DataFrame(pd.date_range("2020-01-01", "2020-02-01", freq="1min"), columns=["start_hour"])
data["TokenUsed"] = np.random.normal(15, 2, data.shape[0])
data["TokenUsed"] = data["TokenUsed"].astype(int)
minute_by_minute = data.groupby([pd.Grouper(key="start_hour", freq="1W"), "start_hour"]).sum()
minute_by_minute.index.rename(level=0, names="week", inplace=True)
minute_by_minute.reset_index(inplace=True)

############ SETTINGS
# Everything the plotting code reads is defined up front, so the snippet no
# longer uses `widths` (or `percentile`) before assignment.
percentile = 95
binSize = 2
widths = 0.8
cmap = sns.light_palette("red", as_cmap=True)

############ BOX PLOT GENERATION
df = minute_by_minute.set_index("week")
y = [
    my_boxplot_stats(df.loc[date]["TokenUsed"], percents=[100-percentile, percentile], labels=[date], whis="min/max")[0]
    for date in df.index.unique()
]

data_box = pd.DataFrame(y)

# Fraction of all rows whose TokenUsed is strictly below the largest q3.
coverage = (df["TokenUsed"] < data_box['q3'].max()).sum()/(len(df["TokenUsed"]))

# The figure has to exist before its title can be set.
fig, ax = plt.subplots()
ax.set_title(f" The max {data_box['q3'].max()} at {data_box.set_index('label')['q3'].idxmax()}\n It covers {coverage:.2%} ")
bplot = ax.bxp(y, patch_artist=True, widths=widths,
               medianprops=dict(color='k', ls='--', lw=2))
ax.yaxis.set_minor_locator(AutoMinorLocator(2))

# imshow changes the view limits; remember the box-plot limits so they can be
# restored once the images are in place.
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# get counts of tokens
token_count = minute_by_minute.groupby(by=["week", "TokenUsed"])["TokenUsed"].count()

# Snap the overall value range outwards to multiples of binSize.
minToken, maxToken = minute_by_minute['TokenUsed'].agg(['min', 'max'])
minToken = int(np.floor(minToken/binSize)*binSize)
maxToken = int(np.ceil(maxToken/binSize)*binSize)
# Guarantee at least one complete bin even if every value is identical.
maxToken = max(maxToken, minToken + binSize)

# include_lowest=True keeps values equal to minToken (bins are right-closed);
# observed=False keeps empty bins, so every week has the same number of rows
# and the images line up with the y axis.
token_bins = pd.cut(minute_by_minute['TokenUsed'],
                    bins=range(minToken, maxToken+1, binSize),
                    include_lowest=True)
binned_tokens = minute_by_minute.groupby(['week', token_bins], observed=False)['TokenUsed'].count()

vmin = binned_tokens.min()
vmax = binned_tokens.max()

im = None
for i, ((w, gr), box) in enumerate(zip(binned_tokens.groupby(level=0), bplot['boxes'])):
    # Two identical columns give a strip with one row per bin.
    strip = np.vstack([gr.values, gr.values]).T
    # origin='lower' puts row 0 (the lowest bin) at the bottom of the extent,
    # matching the ascending y axis.
    im = ax.imshow(strip, aspect='auto', origin='lower',
                   extent=(i+1-widths/2, i+1+widths/2, minToken, maxToken),
                   cmap=cmap, vmin=vmin, vmax=vmax)
    box.set_facecolor('none')
    im.set_clip_path(box)

ax.set_xlim(xlim)
ax.set_ylim(ylim)

# Without any box there is no image to attach a colour bar to.
if im is not None:
    fig.colorbar(im, label='# tokens per bin')
fig.autofmt_xdate()

在此处输入图像描述


推荐阅读