首页 > 解决方案 > 绘制 MNIST 样本

问题描述

我正在尝试从 MNIST 数据集中绘制 10 个样本。每个数字一个。这是代码:

import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data

for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    plt.imshow(plottable_image, cmap='gray_r')
    plt.subplot(2, 5, i + 1)

plt.plot()

出于某种原因,图中跳过了零位。

为什么?

标签: pythonmatplotlibdata-sciencemnist

解决方案


尝试这个:

import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data

fig, ax = plt.subplots(2,5)
ax = ax.flatten()
for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax[i].imshow(plottable_image, cmap='gray_r')

输出:

在此处输入图像描述


推荐阅读