python - 如何使用 MNIST 数据集实现超参数
问题描述
我目前正在 Jupyter notebook 中运行一个程序来对 MNIST 数据集进行分类。我正在尝试使用 KNN 分类器来执行此操作,并且运行需要一个多小时。我是分类器和超参数的新手,似乎没有任何关于如何正确实现其中一个的像样的教程。谁能给我一些关于如何使用超参数进行分类的提示?我已经搜索并看到了 GridSearchCv 和 RandomizedSearchCV。从查看他们的示例来看,他们似乎选择了不同的属性名称并更改为他们的代码所需的名称。如果数据只是手写数字,我不明白如何为 MNIST 数据集做到这一点。看到只有数字,在这种情况下是否不需要超参数?这是我目前仍在运行的代码。
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals
# Common imports
import numpy as np
import os
# to make this notebook's output stable across runs
np.random.seed(42)
# To plot pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"
def save_fig(fig_id, tight_layout=True):
image_dir = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
if not os.path.exists(image_dir):
os.makedirs(image_dir)
path = os.path.join(image_dir, fig_id + ".png")
print("Saving figure", fig_id)
if tight_layout:
plt.tight_layout()
plt.savefig(path, format='png', dpi=300)
def sort_by_target(mnist):
reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]
reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]
mnist.data[:60000] = mnist.data[reorder_train]
mnist.target[:60000] = mnist.target[reorder_train]
mnist.data[60000:] = mnist.data[reorder_test + 60000]
mnist.target[60000:] = mnist.target[reorder_test + 60000]
try:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True)
mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings
sort_by_target(mnist) # fetch_openml() returns an unsorted dataset
except ImportError:
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')
mnist["data"], mnist["target"]
mnist.data.shape
X, y = mnist["data"], mnist["target"]
X.shape
y.shape
#select and display some digit from the dataset
import matplotlib
import matplotlib.pyplot as plt
some_digit_index = 7201
some_digit = X[some_digit_index]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,
interpolation="nearest")
plt.axis("off")
save_fig("some_digit_plot")
plt.show()
#print some digit's label
print('The ground truth label for the digit above is: ',y[some_digit_index])
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
#random shuffle
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
from sklearn.model_selection import cross_val_predict
from sklearn.neighbors import KNeighborsClassifier
y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
knn_clf.predict([some_digit])
y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3, n_jobs=-1)
f1_score(y_multilabel, y_train_knn_pred, average="macro")
解决方案
KNN 最流行的超参数是n_neighbors
,即您考虑将标签分配给新点的最近邻居数。默认情况下,它设置为 5,但它可能不是最佳选择。因此,通常最好为您的特定问题找到最佳选择。
这就是您如何为您的示例找到最佳超参数的方法:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
param_grid = {"n_neighbors" : [3,5,7]}
KNN=KNeighborsClassifier()
grid=GridSearchCV(KNN, param_grid = param_grid , cv = 5, scoring = 'accuracy', return_train_score = False)
grid.fit(X_train,y_train)
这样做是将您的 KNN 模型的性能与您设置的不同值进行比较n_neighbors
。那么当你这样做时:
print(grid.best_score_)
print(grid.best_params_)
它将向您展示最佳性能得分是多少,以及它实现了哪些参数选择。
所有这些都与您使用 MNIST 数据这一事实无关。您可以将此方法用于任何其他分类任务,只要您认为 KNN 可能是您的任务的明智选择(这对于图像分类可能是有争议的)。从一项任务到另一项任务的唯一变化是超参数的最佳值。
PS:我建议不要使用该y_multilabel
术语,因为这可能指的是特定的分类任务,其中每个数据点可能有多个标签,而 MNIST 不是这种情况(每个图像一次只代表一个数字)。
推荐阅读
- perl - 在 Mojolicious/Minion::Job 上设置“完成”事件
- python - 让 Spyder 将 cwd 设置为文件的位置
- nativescript - 在 Nativescript Vue 中对对象数组进行排序的正确方法是什么?
- c++ - 如何显示从 TCPsokcet 接收的图像并在 QML 中显示
- asynchronous - 计算在有和没有异步的情况下在 akka 流中完成的时间流
- android - 使用 Retrofit POST 请求获取对象数组
- javascript - discordjs message.guild.roles.size 在 serverinfo 命令中返回 null
- mysql - 带列的 MySQL 格式字符串
- jenkins - extendedEmail 中的矩阵触发模式 - Jenkins 作业 DSL
- react-native - 更改反应开发工具的端口