首页 > 解决方案 > Stratx 图中的双 Y 轴水平线位置访问

问题描述

我想在stratx(https://github.com/parrt/stratx)plot_stratpd 方法产生的图上画一条穿过 0.0 点的水平线。

在这种情况下如何访问左 Y 轴,以便我可以使用y=0.0

from stratx.partdep import *
X = df.drop('user_retained', axis=1)
y = df['user_retained']

plt.figure(figsize=(16,16), dpi= 80, facecolor='w', edgecolor='k')
plot_stratpd(X, y, 'percentage_of_points', 'user_retained', yrange=(-0.3, 0.6), n_trials=10)
plt.tight_layout()
plt.axhline(y=134, alpha=1, linewidth = 2, linestyle = '-')

plt.show()

由stratx 生成的默认图,手动找到轴线的 Y 值

标签: pythonmatplotlib

解决方案


设置一个Axes并将其传递给plot_stratpd. 然后,您可以使用此 Axes 在常规数据坐标处绘制水平线:

fig,ax = plt.subplots(figsize=(16,16), dpi= 80, facecolor='w', edgecolor='k')
plot_stratpd(X, y, 'percentage_of_points', 'user_retained', yrange=(-0.3, 0.6), n_trials=10, ax=ax)
ax.axhline(y=0, alpha=1, linewidth = 2, linestyle = '-')

例子:

from sklearn.datasets import load_diabetes
from stratx.partdep import *
import matplotlib.pyplot as plt

diabetes = load_diabetes()
df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
df['y'] = diabetes.target
X = df.drop('y', axis=1)
y = df['y']

fig,ax = plt.subplots()

plot_stratpd(X, y, 'bmi', 'y', n_trials=10, ax=ax)
ax.axhline(0)

plt.show()

在此处输入图像描述


推荐阅读