tensorflow - 在Tensorflow分类中,使用“预测”时标签是如何排序的?
问题描述
我正在使用 MNIST 手写数字数据集来训练 CNN。
训练模型后,我使用这样的预测:
predictions = cnn_model.predict(test_images)
predictions[0]
我得到的输出为:
array([2.1273775e-06, 2.9292005e-05, 1.2424786e-06, 7.6307842e-05,
7.4305902e-08, 7.2301691e-07, 2.5368356e-08, 9.9952960e-01,
1.2401938e-06, 1.2787555e-06], dtype=float32)
在输出中,有 10 个概率,从 0 到 9 的每个数字都有一个概率。但是我怎么知道哪个概率指的是哪个数字?
在这种特殊情况下,数字 0 到 9 的概率按顺序排列。但为什么会这样呢?我没有在任何地方定义。
我尝试查看互联网上其他地方的文档和示例实现,但似乎没有人解决这种特殊行为。
编辑:
对于上下文,我通过以下方式定义了我的训练/测试数据:
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = (np.expand_dims(train_images, axis=-1)/255.).astype(np.float32)
train_labels = (train_labels).astype(np.int64)
test_images = (np.expand_dims(test_images, axis=-1)/255.).astype(np.float32)
test_labels = (test_labels).astype(np.int64)
我的模型包括几个卷积层和池化层,然后是 Flatten 层,然后是具有 128 个神经元的 Dense 层和具有 10 个神经元的输出 Dense 层。
之后,我简单地拟合我的模型并像这样使用预测:
model.fit(train_images, train_labels, batch_size=BATCH_SIZE, epochs=EPOCHS)
predictions = cnn_model.predict(test_images)
我没有看到我在哪里指示我的代码将第一个神经元输出为数字 0,第二个神经元作为数字 1 等如果我想更改输出结果数字的顺序,我该在哪里做呢?这真的让我很困惑。
解决方案
这取决于您在训练期间如何准备标签。对于 MNIST 分类,通常有两种不同的方式:
- One-hot 标签:MNIST 数据中有 10 个标签,因此对于每个示例(图像),您创建一个长度为 10 的标签数组(向量),其中所有元素都为零,除了与输入图像的数字对应的索引正在显示。例如,如果您的输入图像显示数字8,则您的标签在除第8 个索引处以外的所有位置都包含零(例如 [0,0,0,0,0,0,0,0, 1 ,0])。如果您的图像显示数字2,则您的标签将类似于 [0,0, 1 ,0,0,0,0,0,0,0] 等等。
- 稀疏标签:您只需通过显示的数字直接标记每个图像,例如,如果您的图像显示数字8,则您的标签是一个值为 8 的数字。
在这两种情况下,您都可以根据需要选择标签,在 MNIST 分类中,使用标签 0-9 来显示数字 0-9 很直观。
因此,在预测中,索引 0 处的概率用于数字 0,索引 1 用于数字 1,依此类推。
您可以选择以不同的方式准备标签。例如,您可以决定如下显示您的标签:
- 数字 0 的标签:9
- 数字 1 的标签:8
- 数字 2 的标签:7
- 数字 3 的标签:6
- 数字 4 的标签:5
- 数字 5 的标签:4
- 数字 6 的标签:3
- 数字 7 的标签:2
- 数字 8 的标签:1
- 数字 9 的标签:0
您可以以相同的方式训练您的模型,但在这种情况下,预测中的概率将被反转。索引 0 处的概率为数字 9,索引 1 为数字 8,依此类推。
简而言之,您必须使用整数索引定义标签,但由您决定并记住您选择的索引来引用哪个标签/类。
推荐阅读
- reactjs - ReferenceError:未定义 babelHelpers ./node_modules/react-avatar-editor/dist/index.js/
我得到一个 ReferenceError: babelHelpers is not defined,同时导入 react-avatar-editor, 请帮助!
- python - 如何修复 Python Flask 应用程序中的“方法不允许”错误
- google-authentication - ArgumentException:必须提供“ClientId”选项。(参数'ClientId')
- javascript - 我如何知道元素是网格上的哪个图块?
- c - 如何在C中读取混合数据的文件?
- ruby-on-rails - 如何在 searchkick 中编写布尔查询
- javascript - 试图找到一种用我的数组显示图像的方法
- android - PlatformException(ERROR_INVALID_EMAIL, 电子邮件地址格式错误。, null)
- c# - c# foreach loop list with class
- javascript - Bootstrap 需要 popper.js,尽管在 Angular 中使用 bundle