首页 > 解决方案 > 如何使用带有 matplotlib 的 pandas 创建 3D 绘图

问题描述

我正在为在 matplot lib 上以 3D 呈现数据所需的 pandas 转换而苦苦挣扎。我拥有的数据通常是数字列(通常是时间和一些值)。所以让我们创建一些测试数据来说明。

import pandas as pd

pattern = ("....1...."
           "....1...."
           "..11111.."
           ".1133311."
           "111393111"
           ".1133311."
           "..11111.."
           "....1...."
           "....1....")

# create the data and coords
Zdata = list(map(lambda d:0 if d == '.' else int(d), pattern))
Zinverse = list(map(lambda d:1 if d == '.' else -int(d), pattern))
Xdata = [x for y in range(1,10) for x in range(1,10)]
Ydata = [y for y in range(1,10) for x in range(1,10)]
# pivot the data into columns
data = [d for d in zip(Xdata,Ydata,Zdata,Zinverse)]

# create the data frame 
df = pd.DataFrame(data, columns=['X','Y','Z',"Zi"], index=zip(Xdata,Ydata))
df.head(5)

在此处输入图像描述

编辑:这个数据块是演示数据,通常来自对数据库的查询,在绘图之前可能需要更多的清理和转换。在这种情况下,数据已经对齐,除了多一列我们不需要(Zi)之外没有任何问题。

因此,其中的数字pattern被转移到('Zi' 是反向图像)的 Z 列中的高度数据中,df并且作为数据框,我一直在努力想出这个由 3 个独立操作组成的枢轴方法。我想知道这是否可以更好。

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.cm as cm

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

Xs = df.pivot(index='X', columns='Y', values='X').values
Ys = df.pivot(index='X', columns='Y', values='Y').values
Zs = df.pivot(index='X', columns='Y', values='Z').values

ax.plot_surface(Xs,Ys,Zs, cmap=cm.RdYlGn)

plt.show()

在此处输入图像描述

虽然我有一些工作,但我觉得肯定有比我正在做的更好的方法。在一个大数据集上,我认为做 3 个枢轴是绘制某些东西的昂贵方式。有没有更有效的方法来转换这些数据?

标签: pythonpandasmatplotlib

解决方案


我想您可以通过不使用 pandas(但仅使用 numpy 数组)和使用 numpy 提供的一些便利功能(例如linespacemeshgrid )来避免准备数据期间的一些步骤。

为此,我重写了您的代码,试图保持相同的逻辑和相同的变量名:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

pattern = ("....1...."
           "....1...."
           "..11111.."
           ".1133311."
           "111393111"
           ".1133311."
           "..11111.."
           "....1...."
           "....1....")


# Extract the value according to your logic
Zdata = list(map(lambda d:0 if d == '.' else int(d), pattern))

# Assuming the pattern is always a square
size = int(len(Zdata) ** 0.5)

# Create a mesh grid for plotting the surface
Xdata = np.linspace(1, size, size)
Ydata = np.linspace(1, size, size)
Xs, Ys = np.meshgrid(Xdata, Ydata)

# Convert the Zdata to a numpy array with the appropriate shape
Zs = np.array(Zdata).reshape((size, size))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Plot the surface
ax.plot_surface(Xs, Ys, Zs, cmap=cm.RdYlGn)
plt.show()

推荐阅读