python - 当我的输出标签在 Tensorflow 中是偶数时,如何制作大小为 5 的输出层?
问题描述
我的训练数据中有标签 0、2、4、6、8。所以,从技术上讲,我有 5 个类,但是当我编写以下代码时,
model = keras.Sequential([
layers.Dense(units=512, activation="relu",input_shape=[784]),
layers.Dense(units=512, activation="relu"),
layers.Dense(units=5,activation="softmax")
])
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'],
)
history = model.fit(
training_X, training_Y.reshape(2465,1),
validation_data=(val_X, val_Y.reshape(986,1)),
epochs=50,
verbose = 1
)
我得到这个错误,
InvalidArgumentError: Received a label value of 8 which is outside the valid range of [0, 5). Label values: 6 8 4 6 2 4 8 0 2 2 4 6 0 2 6 4 4 2 2 8 0 0 6 0 2 8 0 2 2 6 4 4
那么,如何只使用 5 个输出单元并针对这些标签进行训练呢?
解决方案
作为您收到的错误,您的整数标签应该更像0, 1, 2, 3, 4
而不是0,2,4,6,8
. 您可以转换标签来解决问题,也可以将标签转换为热编码向量,如下所示。
import numpy as np
import pandas as pd
x = np.random.randint(0, 256, size=(5, 784)).astype("float32")
y = pd.get_dummies(np.array([0, 2, 4, 6, 8])).values
x.shape, y.shape
((5, 784), (5, 5))
y
array([[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]], dtype=uint8)
此外,您还需要使用categorical_crossentropy
损失函数而不是sparse_categorical_crossentropy
. 完整的工作代码:
model = keras.Sequential([
layers.Dense(units=512, activation="relu",input_shape=[784]),
layers.Dense(units=512, activation="relu"),
layers.Dense(units=5,activation="softmax")
])
model.compile(
optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'],
)
history = model.fit(
x, y,
epochs=3, verbose=2
)
Epoch 1/3
380ms/step - loss: 153.8635 - accuracy: 0.2000
Epoch 2/3
16ms/step - loss: 194.0642 - accuracy: 0.6000
Epoch 3/3
16ms/step - loss: 259.9468 - accuracy: 0.6000
推荐阅读
- cakephp - Cakephp:encrypt() 的密钥无效,设置 cookie 时密钥必须至少为 256 位(32 字节)长
- postgresql - 循环遍历 Postgresql 中的列
- pandas - 将张量流数据集转换为熊猫数据框
- javascript - 尝试在 JS 中加载 JSON 样式数据
- c++ - 通过 boost 和 c++ 进行 smtp 身份验证
- ruby-on-rails - 如何在没有身份验证的情况下解析 ruby 中的 uri
- python - Python 3 中 configparser 中的 read 和 read_file 有什么区别?
- python - 根据其他键值对填充空白字典值
- python - Google Sheet Api 写入错误
- sql - NamedEntityGraph 仍然导致延迟加载异常