首页 > 解决方案 > Python matplotlib:图例给出了错误的散点结果

问题描述

我正在尝试使用不同的降维技术来可视化时尚 MNIST 数据集,并且我还想使用所谓的 real_labels 将图例附加到结果图片中,它告诉标签的真实名称。对于时尚 MNIST 真实标签是:

real_labels = ['t-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot']

我正在以下功能中进行绘图部分:

def Draw_datasamples_to_figure(X_scaled, labels, axis):
    y = ['${}$'.format(i) for i in labels]
    num_cls = len(list(set(labels)))
    for (X_plot, Y_plot, y1, label1) in zip(X_scaled[:,0], X_scaled[:,1], y, labels):
        axis.scatter(X_plot, Y_plot, color=cm.gnuplot(int(label1)/num_cls),label=y1, marker=y1, s=60)

,其中 X_scaled 表示 x 和 y 坐标,标签是整数 (0-9) 用于类信息,axis 表示将在哪个子图窗口中绘制图片。

使用以下命令绘制图例:

ax3.legend(real_labels, loc='center left', bbox_to_anchor=(1, 0.5))

一切似乎都很好,直到传说被描绘出来。如下图所示,图例中选择的数字不是从 0 到 9,而是任意的。

在此处输入图像描述

我知道问题可能出在分散部分,我应该以另一种方式实现它,但我希望仍然有一些我想念的简单的东西可以修复我的实现。我不想在代码中定义标记和名称的地方使用手工制作的图例,因为我还有其他具有不同类和真实标签名称的数据集。提前致谢!

标签: pythonmatplotliblegend

解决方案


推荐阅读