首页 > 解决方案 > 使用 Matplotlib 绘制 3D 散点图时如何更改图例文本?

问题描述

我有一个使用以下代码生成的 3D 散点图

import seaborn as sns
import numpy as np

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
# Create an example dataframe
data = {'th': [1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2],
        'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2,0.2,0.2,0.2,0.3,0.3,0.4,0.1,1,1.1,3,1],
        'my_val': [1.2,3.2,4,5.1,1,2,5.1,1,2,4,1,3,6,6,2,3],
        'name':['a','b','c','d','a','b','c','d','a','b','c','d','a','b','c','d']}
df = pd.DataFrame(data)
# convert unique str into unique int
order_dict = {k: i for i, k in enumerate ( df ['name'])}
df ['name_int'] = df ['name'].map ( order_dict )
data_np=df.to_numpy()

# generate data

x = data_np[:,0]
y = data_np[:,1]
z = data_np[:,2]

# axes instance
fig = plt.figure(figsize=(10,6))
ax = Axes3D(fig)

# get colormap from seaborn
cmap = ListedColormap(sns.color_palette("husl", 256).as_hex())

# plot
sc = ax.scatter(x, y, z, s=40, c=data_np[:,4], marker='o', cmap=cmap, alpha=1)
ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')

# legend
plt.legend(*sc.legend_elements(), bbox_to_anchor=(1.05, 1), loc=2)
plt.show()

这个产品

在此处输入图像描述

在上面,我必须将nameinto类型转换为唯一 acceptinteger的 para 。结果,图例是根据值而不是原始的映射。cax.scatternumbernumericname

我可以知道如何用name而不是数字表示来使用图例吗?

标签: pythonmatplotlib

解决方案


使用 pandas 进行转换和选择可以简化代码。通过分别为每个“名称”绘制散点图,可以为每个“名称”赋予一个标签作为图例。

这是改编后的代码:

import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Create an example dataframe
data = {'th': [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2],
        'pdvalue': [0.5, 0.5, 0.5, 0.5, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.4, 0.1, 1, 1.1, 3, 1],
        'my_val': [1.2, 3.2, 4, 5.1, 1, 2, 5.1, 1, 2, 4, 1, 3, 6, 6, 2, 3],
        'name': ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd']}
df = pd.DataFrame(data)

# axes instance
fig = plt.figure(figsize=(10, 6))
ax = Axes3D(fig, auto_add_to_figure=False)
fig.add_axes(ax)

# find all the unique labels in the 'name' column
labels = np.unique(df['name'])
# get palette from seaborn
palette = sns.color_palette("husl", len(labels))

# plot
for label, color in zip(labels, palette):
    df1 = df[df['name'] == label]
    ax.scatter(df1['th'], df1['pdvalue'], df1['my_val'],
               s=40, marker='o', color=color, alpha=1, label=label)
ax.set_xlabel('th')
ax.set_ylabel('pdvalue')
ax.set_zlabel('my_val')

# legend
plt.legend(bbox_to_anchor=(1.05, 1), loc=2)
plt.show()

带有图例的 3d 散点图


推荐阅读