python - 如何使用 scikit-learn 可视化两个类的边界/决策函数
问题描述
我是机器学习的新手,所以我仍然不明白如何在词袋案例中可视化两个类之间的边界。
我发现下面的例子来绘制数据
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
newsgroups_train = fetch_20newsgroups(subset='train',
categories=['alt.atheism', 'sci.space'])
pipeline = Pipeline([
('vect', CountVectorizer()),
('tfidf', TfidfTransformer()),
])
X = pipeline.fit_transform(newsgroups_train.data).todense()
pca = PCA(n_components=2).fit(X)
data2D = pca.transform(X)
plt.scatter(data2D[:,0], data2D[:,1], c=newsgroups_train.target)
plt.show()
在我的项目中,我使用 SVC 估算器
clf = SVC(random_state=241, kernel = 'linear')
clf.fit(X,newsgroups_train.target)
我尝试使用示例 http://scikit-learn.org/stable/auto_examples/svm/plot_iris.html 但它在文本分类案例中不起作用
那么如何在这个图中添加两个类的边界呢?
谢谢!
解决方案
问题是您只需要选择 2 个特征来创建二维决策曲面图。我将提供 2 个示例。第一个使用iris
数据,第二个使用your
数据。
在这两种情况下,我只选择 2 个特征来创建绘图。
使用 iris 数据的示例 1:
from sklearn.svm import SVC
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
iris = datasets.load_iris()
X = iris.data[:, :2] # we only take the first two features.
y = iris.target
def make_meshgrid(x, y, h=.02):
x_min, x_max = x.min() - 1, x.max() + 1
y_min, y_max = y.min() - 1, y.max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
return xx, yy
def plot_contours(ax, clf, xx, yy, **params):
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
out = ax.contourf(xx, yy, Z, **params)
return out
model = svm.SVC(kernel='linear')
clf = model.fit(X, y)
fig, ax = plt.subplots()
# title for the plots
title = ('Decision surface of linear SVC ')
# Set-up grid for plotting.
X0, X1 = X[:, 0], X[:, 1]
xx, yy = make_meshgrid(X0, X1)
plot_contours(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8)
ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
ax.set_ylabel('y label here')
ax.set_xlabel('x label here')
ax.set_xticks(())
ax.set_yticks(())
ax.set_title(title)
ax.legend()
plt.show()
使用您的数据的示例 2:
from sklearn.svm import SVC
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
newsgroups_train = fetch_20newsgroups(subset='train',
categories=['alt.atheism', 'sci.space'])
pipeline = Pipeline([('vect', CountVectorizer()), ('tfidf', TfidfTransformer())])
X = pipeline.fit_transform(newsgroups_train.data).todense()
# Select ONLY 2 features
X = np.array(X)
X = X[:, [0,1]]
y = newsgroups_train.target
def make_meshgrid(x, y, h=.02):
x_min, x_max = x.min() - 1, x.max() + 1
y_min, y_max = y.min() - 1, y.max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
return xx, yy
def plot_contours(ax, clf, xx, yy, **params):
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
out = ax.contourf(xx, yy, Z, **params)
return out
model = svm.SVC(kernel='linear')
clf = model.fit(X, y)
fig, ax = plt.subplots()
# title for the plots
title = ('Decision surface of linear SVC ')
# Set-up grid for plotting.
X0, X1 = X[:, 0], X[:, 1]
xx, yy = make_meshgrid(X0, X1)
plot_contours(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8)
ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')
ax.set_ylabel('y label here')
ax.set_xlabel('x label here')
ax.set_xticks(())
ax.set_yticks(())
ax.set_title(title)
ax.legend()
plt.show()
结果
重要的提示:
在第二种情况下,情节并不好,因为我们只随机选择了 2 个特征来创建它。使其变得更好的一种方法如下:您可以使用univariate ranking method
(例如 ANOVA F 值测试)并从您最初拥有的特征中找到最佳top-2
特征。22464
然后使用这些top-2
你可以创建一个很好的分离表面图。
推荐阅读
- c# - Blazor WebAssembly:同一组件渲染上的多个路由
- antd - antd 菜单图标 API 不工作。不渲染
- javascript - 开玩笑模拟和设置默认行为
- javascript - 如何使用 Angular 在日期范围内创建搜索过滤器?
- python - Altair 中的并排箱线图
- angular - 导出数组接口以避免“类型上不存在属性'push'的正确方法......”在Angular中
- java - Angular + Spring Boot跨域过滤器不起作用
- ios - 与会者出现 Swift 问题,无法获取姓名、电子邮件等...我不断收到错误消息
- amazon-web-services - 使用 SSM 的 CloudWatch 代理,其中实例未显示在托管实例中
- c# - 为什么此应用程序尝试连接到 SQL Server 而不是 SQLite?