python-3.x - 我最后一个密集的 keras 层有什么问题?
问题描述
我正在研究 keras 中的一个小型 NN,用于解决多类分类问题。我有 9 个不同的标签,我的特征也是 9 个。
我的火车/测试形状如下:
Sets shape:
x_train shape: (7079, 9)
y_train shape: (7079,)
x_test shape: (7079, 9)
y_test shape: (7079,)
但是当我试图让它们分类时:
y_train = tf.keras.utils.to_categorical(y_train, num_classes=9)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=9)
我收到以下错误:
IndexError: index 9 is out of bounds for axis 1 with size 9
这里有更多关于y_train
print(np.unique(y_train)) # [1. 2. 3. 4. 5. 6. 7. 8. 9.]
print(len(np.unique(y_train))) # 9
任何人都会知道问题是什么?
解决方案
的形状y_train
是1D
。您必须对其进行一次热编码。就像是
y_train = tf.keras.utils.to_categorical(y_train , num_classes=9)
也是y_test
如此。
更新
根据文档,
tf.keras.utils.to_categorical(y, num_classes=None, dtype="float32")
这里,y:要转换为矩阵的类向量(从0
到的整数num_classes
)。与您的情况一样,y_train
类似于[1,2,..]
. 您需要执行以下操作:
y_train = tf.keras.utils.to_categorical(y_train - 1, num_classes=9)
这是一个供参考的例子。如果我们这样做
class_vector = np.array([1, 1, 2, 3, 5, 1, 4, 2])
print(class_vector)
output_matrix = tf.keras.utils.to_categorical(class_vector,
num_classes = 5, dtype ="float32")
print(output_matrix)
[1 1 2 3 5 1 4 2]
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-15-69c8be7a0f1a> in <module>()
6 print(class_vector)
7
----> 8 output_matrix = tf.keras.utils.to_categorical(class_vector, num_classes = 5, dtype ="float32")
9 print(output_matrix)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/utils/np_utils.py in to_categorical(y, num_classes, dtype)
76 n = y.shape[0]
77 categorical = np.zeros((n, num_classes), dtype=dtype)
---> 78 categorical[np.arange(n), y] = 1
79 output_shape = input_shape + (num_classes,)
80 categorical = np.reshape(categorical, output_shape)
IndexError: index 5 is out of bounds for axis 1 with size 5
为了解决这个问题,我们将数据转换为从零开始的格式。
output_matrix = tf.keras.utils.to_categorical(class_vector - 1,
num_classes = 5, dtype ="float32")
print(output_matrix)
[[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 1. 0. 0. 0.]]
推荐阅读
- swift - 在 Xcode 模拟器上运行预览
- terraform - 如何从 terraform 列表中选择元素
- c - 需要在多个数组中显示前 3 个最高数字
- android - 成功登录后如何使“登录布局”更改布局
- javascript - 路由的组件必须是 React native expo 中的 react 组件错误
- dataweave - 将带有父 ID 和子数组的 JSON 转换为 Dataweave 中子数组中每个元素一个 json 对象的数组
- java - 使用鼠标和键盘在 3D 空间中移动 [JavaFX3D]
- java - 如何根据前 10 位数字查找重复值?
- ios - 如何在应用启动时监控订阅状态(swift)
- android - 在断线后从 TextView 中删除无用的空间