python - 使用 Keras 的多类分类中的数组形状错误
问题描述
我有一个由一个输入(一个整数)和一个输出组成的数据集,成为一个标签,例如:
3042,0
3338,1
1162,3
1605,2
...
所以最后一列应该成为标签的单热编码(使用 Keras 的 to_categorical()),例如:
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]]
我的输入和输出形状是
X_data.shape: (2407060,)
y_data.shape: (2407060, 4)
但是,我收到一个错误,即我的输出应该具有形状 (1,) 而不是 (4,),即使我的最后一层有 4 个输出。
Using TensorFlow backend.
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 16) 32
_________________________________________________________________
...
_________________________________________________________________
dense_9 (Dense) (None, 4) 68
=================================================================
Total params: 63,156
Trainable params: 63,156
Non-trainable params: 0
_________________________________________________________________
ValueError: Error when checking target: expected dense_9 to have shape (1,) but got array with shape (4,)
这是代码:
model = Sequential()
model.add(Dense(16, activation='sigmoid', input_dim=1))
model.add(Dense(32, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(LeakyReLU())
model.add(Dense(64, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(16, activation='relu'))
model.add(Dense(4, activation='sigmoid'))
model.compile(optimizer='nadam',
loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
X_data = np.loadtxt(data_file, delimiter=',', usecols=(3))
y_data = to_categorical(np.loadtxt(data_file, delimiter=',', usecols=(7)))
model.fit(X_data, y_data, epochs=20, validation_split=0.3, verbose=1, callbacks=[cp_callback])
到底是怎么回事?
解决方案
如果输出是单热格式,则应将损失函数更改为 categorical_crossentropy。
[1,0,0] [0,1,0] [0,0,1]
如果目标是整数,则可以使用 sparse_categorical_crossentropy。
1、2、3
推荐阅读
- java - 有没有办法在不使用 Thread.sleep 的情况下对 ScheduledExecutorService.scheduleAtFixedRate 进行单元测试?
- macos - 颤振看不到 ios-deploy
- javascript - 为什么相对图像路径在我的 React 应用程序中不起作用?
- gnupg - GPG 无法连接到 S.gpg-agent:连接被拒绝
- python - MongoDB推送到嵌套数组或插入新文档
- ios - React Native 模块,找不到 Cocoapod 文件
- c# - 有没有办法从子转发器中选择一行并将其插入到 ASP.NET C# 中的父转发器行/字段中?
- r - Making a 3D cylinder out of a polygon
- postgresql - Postgres 全文搜索:短语运算符 (
) distance 正在寻找精确的距离匹配 - javascript - (HTML,CSS,JS)尝试添加可点击的下拉菜单,图标按钮