python - 如何绘制 Python 3 维水平集?
问题描述
我在绘制脑海中的图像时遇到了一些麻烦。我想用支持向量机可视化内核技巧。所以我制作了一些由两个圆(一个内圆和一个外圆)组成的二维数据,它们应该被一个超平面隔开。显然这在二维中是不可能的——所以我将它们转换为 3D。设 n 为样本数。现在我有一个 (n,3)-array (3 columns, n rows) X 数据点和一个 (n,1)-array y 和标签。使用 sklearn 我得到线性分类器
clf = svm.SVC(kernel='linear', C=1000)
clf.fit(X, y)
我已经将数据点绘制为散点图
plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired)
现在我想将分离超平面绘制为曲面图。我的问题是缺少超平面的显式表示,因为决策函数仅通过decision_function = 0
. 因此,我需要绘制一个 4 维对象的级别集(级别 0)。
由于我不是python专家,如果有人可以帮助我,我将不胜感激!而且我知道这并不是使用 SVM 的真正“风格”,但我需要这张图片作为我论文的插图。
编辑:我当前的“代码”
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.datasets import make_blobs, make_circles
from tikzplotlib import save as tikz_save
plt.close('all')
# we create 50 separable points
#X, y = make_blobs(n_samples=40, centers=2, random_state=6)
X, y = make_circles(n_samples=50, factor=0.5, random_state=4, noise=.05)
X2, y2 = make_circles(n_samples=50, factor=0.2, random_state=5, noise=.08)
X = np.append(X,X2, axis=0)
y = np.append(y,y2, axis=0)
# shifte X to [0,2]x[0,2]
X = np.array([[item[0] + 1, item[1] + 1] for item in X])
X[X<0] = 0.01
clf = svm.SVC(kernel='rbf', C=1000)
clf.fit(X, y)
plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired)
# plot the decision function
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# create grid to evaluate model
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf.decision_function(xy).reshape(XX.shape)
# plot decision boundary and margins
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--','-','--'])
# plot support vectors
ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100,
linewidth=1, facecolors='none', edgecolors='k')
################## KERNEL TRICK - 3D ##################
trans_X = np.array([[item[0]**2, item[1]**2, np.sqrt(2*item[0]*item[1])] for item in X])
fig = plt.figure()
ax = plt.axes(projection ="3d")
# creating scatter plot
ax.scatter3D(trans_X[:,0],trans_X[:,1],trans_X[:,2], c = y, cmap=plt.cm.Paired)
clf2 = svm.SVC(kernel='linear', C=1000)
clf2.fit(trans_X, y)
ax = plt.gca(projection='3d')
xlim = ax.get_xlim()
ylim = ax.get_ylim()
zlim = ax.get_zlim()
### from here i don't know what to do ###
xx = np.linspace(xlim[0], xlim[1], 3)
yy = np.linspace(ylim[0], ylim[1], 3)
zz = np.linspace(zlim[0], zlim[1], 3)
ZZ, YY, XX = np.meshgrid(zz, yy, xx)
xyz = np.vstack([XX.ravel(), YY.ravel(), ZZ.ravel()]).T
Z = clf2.decision_function(xyz).reshape(XX.shape)
#ax.contour(XX, YY, ZZ, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--','-','--'])
期望的输出
解决方案
您的部分问题已在有关线性内核 SVM的问题中得到解决。这是一个部分答案,因为只有线性内核可以用这种方式表示,也就是说,由于使用线性内核时可以通过估计器访问超平面坐标。
另一种解决方案是找到等值面marching_cubes
该解决方案涉及安装scikit-image
工具包(https://scikit-image.org),该工具包允许从 3D 坐标的网格网格中找到给定值的等值面(这里,我认为 0,因为它表示到超平面的距离) .
在下面的代码(从你的代码中复制)中,我为任何内核实现了这个想法(在示例中,我使用了 RBF 内核),并且输出显示在代码下方。请考虑我关于使用 matplotlib 进行 3D 绘图的脚注,这可能是您的另一个问题。
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from skimage import measure
from sklearn.datasets import make_blobs, make_circles
from tikzplotlib import save as tikz_save
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
plt.close('all')
# we create 50 separable points
#X, y = make_blobs(n_samples=40, centers=2, random_state=6)
X, y = make_circles(n_samples=50, factor=0.5, random_state=4, noise=.05)
X2, y2 = make_circles(n_samples=50, factor=0.2, random_state=5, noise=.08)
X = np.append(X,X2, axis=0)
y = np.append(y,y2, axis=0)
# shifte X to [0,2]x[0,2]
X = np.array([[item[0] + 1, item[1] + 1] for item in X])
X[X<0] = 0.01
clf = svm.SVC(kernel='rbf', C=1000)
clf.fit(X, y)
plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired)
# plot the decision function
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# create grid to evaluate model
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf.decision_function(xy).reshape(XX.shape)
# plot decision boundary and margins
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--','-','--'])
# plot support vectors
ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100,
linewidth=1, facecolors='none', edgecolors='k')
################## KERNEL TRICK - 3D ##################
trans_X = np.array([[item[0]**2, item[1]**2, np.sqrt(2*item[0]*item[1])] for item in X])
fig = plt.figure()
ax = plt.axes(projection ="3d")
# creating scatter plot
ax.scatter3D(trans_X[:,0],trans_X[:,1],trans_X[:,2], c = y, cmap=plt.cm.Paired)
clf2 = svm.SVC(kernel='rbf', C=1000)
clf2.fit(trans_X, y)
z = lambda x,y: (-clf2.intercept_[0]-clf2.coef_[0][0]*x-clf2.coef_[0][1]*y) / clf2.coef_[0][2]
ax = plt.gca(projection='3d')
xlim = ax.get_xlim()
ylim = ax.get_ylim()
zlim = ax.get_zlim()
### from here i don't know what to do ###
xx = np.linspace(xlim[0], xlim[1], 50)
yy = np.linspace(ylim[0], ylim[1], 50)
zz = np.linspace(zlim[0], zlim[1], 50)
XX ,YY, ZZ = np.meshgrid(xx, yy, zz)
xyz = np.vstack([XX.ravel(), YY.ravel(), ZZ.ravel()]).T
Z = clf2.decision_function(xyz).reshape(XX.shape)
# find isosurface with marching cubes
dx = xx[1] - xx[0]
dy = yy[1] - yy[0]
dz = zz[1] - zz[0]
verts, faces, _, _ = measure.marching_cubes_lewiner(Z, 0, spacing=(1, 1, 1), step_size=2)
verts *= np.array([dx, dy, dz])
verts -= np.array([xlim[0], ylim[0], zlim[0]])
# add as Poly3DCollection
mesh = Poly3DCollection(verts[faces])
mesh.set_facecolor('g')
mesh.set_edgecolor('none')
mesh.set_alpha(0.3)
ax.add_collection3d(mesh)
ax.view_init(20, -45)
plt.savefig('kerneltrick')
使用 Matplotlib 运行代码会生成以下图像,其中绿色半透明表面表示非线性决策边界。
脚注:使用 matplotlib 进行 3D 绘图
请注意,在某些情况下,Matplotlib 3D 无法管理对象的“深度”,因为它可能与zorder
该对象的“深度”冲突。这就是为什么有时超平面看起来被绘制在点的“顶部”,即使它应该在“后面”。此问题是matplotlib 3d 文档和此答案中讨论的已知错误。
如果您想获得更好的渲染结果,您可能需要使用Matplotlib 开发人员推荐的Mayavi或任何其他 3D Python 绘图库。
推荐阅读
- html - Bootstrap 模态体高度不变
- python - TypeError:填充 TF 数据集对象当前不支持填充稀疏张量的批处理
- react-native - BackHandler 不会关闭 React Native 应用程序
- lucee - Lucee - 启动缓慢
- firebase - 通知点击与否........点击时间戳
- wpf - 如果调用 ListCollectionView.AddNewItem,则不允许出现 DeferRefresh 错误。为什么?
- python - “scipy.optimize.minimize”如何强制系数不为零
- sql - 带有条件的 3 个表的 SQL 查询
- windows - 使用 clip.exe 复制到剪贴板
- python - 如何快速创建具有重复值的数据框