首页 > 解决方案 > seaborn 子图中每一行的颜色条

问题描述

我看过类似的线程,但没有一个对我有用。我想要做的是显示 2x4 热图,最后每行都有一个联合颜色条,高度与热图相同。

axes用 2x5 创建,这样每行最后都有 4 个矩阵和 1 个颜色条轴。这l2_distances是一个包含四个条目的字典,其中每个键MAPPING_DICT都与一个矩阵相关联(全部相同大小)。

我认为最好的方法是为cbar=False我绘制的每个热图设置并将它们放入轴(0-3),而每行的最后一个热图,即在索引处绘制的轴3将颜色条绘制在cbar_ax=axes[0,4].

import seaborn as sns
import matplotlib.pyplot as plt

 MAPPING_DICT = {"P": 0, "A": 1, "C": 2, "S": 3}

 fig, axes = plt.subplots(2,5, sharex=True, sharey=True)
 for env_name in l2_distances:
     l2_dist_matrix = l2_distances[env_name]
     cbar_flag = True if MAPPING_DICT[env_name]==3 else False
     sns.heatmap(l2_dist_matrix, ax=axes[0, MAPPING_DICT[env_name]], linewidths=0.2, square=True, cbar=cbar_flag, cbar_ax=axes[0,4], cmap="Blues", xticklabels=False, yticklabels=False, robust=True)

但是,这并不完全有效,因为颜色条绘制在(有点)正确的位置但没有标签且高度错误。这只是它看起来像的顶行(添加了一些对颜色图行为没有影响的额外可视化添加),底行基本上是类似的: 2x5 子图的第一行

我已经尝试过明确设置新轴的位置,但这相当乏味,而且效果不太好。有什么我想念的吗?

标签: pythonmatplotlibseabornvisualization

解决方案


主要问题是使用sharex=Truesharey=True会给颜色条与其他子图相同的轴。这与彩条混淆太多了。

如何从 Matplotlib 中的两个轴取消设置 sharex 或 sharey显示了一种删除颜色条共享的方法。这仍然有一些我无法解决的棘手副作用。

下面的解决方案创建子图sharex=Falsesharey=False然后开始共享除颜色条之外的所有子图。由于颜色条不需要像其他子图一样宽,width_ratios可以设置适当的。

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

MAPPING_DICT = {"P": 0, "A": 1, "C": 2, "S": 3}
l2_distances = {"P": np.random.rand(10, 10), "A": np.random.rand(10, 10), "C": np.random.rand(10, 10),
                "S": np.random.rand(10, 10)}

fig, axes = plt.subplots(nrows=2, ncols=5, sharex=False, sharey=False, figsize=(16, 8),
                         gridspec_kw={'width_ratios': [10, 10, 10, 10, 1]})
shax = axes[0, 0].get_shared_x_axes()
shay = axes[0, 0].get_shared_y_axes()
for ax in axes[:, :-1].ravel():
    shax.join(axes[0, 0], ax)
    shay.join(axes[0, 0], ax)
for row in range(axes.shape[0]):
    for env_name in l2_distances:
        l2_dist_matrix = l2_distances[env_name]
        print(env_name, l2_dist_matrix.shape)
        cbar_flag = True if MAPPING_DICT[env_name] == 3 else False
        sns.heatmap(l2_dist_matrix, ax=axes[row, MAPPING_DICT[env_name]], linewidths=0.2, square=True,
                    cbar=cbar_flag, cbar_ax=axes[row, -1], cmap="Blues", xticklabels=False, yticklabels=False,
                    robust=True)
plt.show()

示例图


推荐阅读