python - Keras多输出numpy转换
问题描述
我有一个y_train
形状为(samples, 9)
. 我的X_train
输入是(samples, 30, 1)
.
这些贯穿一个构建为的模型:
def create_model(input_shape, outputs):
i = Input(shape=input_shape)
x = Dense(256, activation="relu")(i)
x = Dropout(0.5)(x)
x = Dense(128, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(64, activation="relu")(x)
x = Dropout(0.5)(x)
x = Flatten()(x)
# Optimize each binary output independently.
o = list(map(lambda _: Dense(1, activation='sigmoid')(x), range(outputs)))
m = Model(i, o)
m.compile('adam', loss='binary_crossentropy', metrics=['accuracy'])
return m
model = create_model((30, 1), 9)
这会产生训练错误:
检查模型目标时出错:您传递给模型的 Numpy 数组列表不是模型预期的大小。预计会看到 9 个数组,但得到了以下 1 个数组的列表:
[array([[1., 1., 1., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 1., 1., 1.],
[0., 0., 0., ..., 0., 0., 0....
我尝试了y_train
形状的变化,使用(9, samples, 1)
和(samples, 1, 9)
。keras 想看到我的(samples, 9)
数组形状如何变化?
解决方案
您的模型有 9 个输出层,每个都有二进制交叉熵。因此,您需要将输出作为 9 个输出的列表传递,其中每个输出都是一个(samples, 1)
大小合适的数组,而不是具有 9 列的单个数组。
因此,您需要执行以下操作。
# Assuming your y_train is of size (samples, 9)
y_train_list = np.split(y_train, y_train.shape[1], axis=1)
model.fit(x_train, y_train_list)
这是一个带有玩具数据的工作示例
x_train = np.random.normal(size=(500,30,1))
y_train = np.random.choice([0,1], size=(500, 9))
y_train_list = np.split(y_train, y_train.shape[1], axis=1)
model.fit(X_tr, y_train_list)
使用 train_test_split 创建训练验证数据
from sklearn.model_selection import train_test_split
tr_x, ts_x, tr_y, ts_y =train_test_split(X_tr, Y_tr, test_size=0.33)
tr_list_y = np.split(tr_y, tr_y.shape[1], axis=1)
ts_list_y = np.split(ts_y, ts_y.shape[1], axis=1)
model.fit(tr_x, tr_list_y, validation_data=(ts_x, ts_list_y))
推荐阅读
- mysql - 求和并仅标记最接近 MySQL 中控制值的值
- python - 熊猫根据条件从列值创建列表
- node.js - nodemon 应用程序崩溃 - 在开始之前等待文件更改......如何修复它?
- databricks - 有没有一种方法可以在 Databricks 中描述 Key Vault 支持的范围,以了解它指向哪个 Key Vault?
- react-native - nsnumber 类型的 json 值“1”无法转换为 uiedgeinsets
- django - Django+gunicorn+nginx上传大文件连接重置错误
- angular8 - 如何在Angular 8上的可点击项目上使用调用函数?
- docker-compose - 将日志配置添加到 docker-compose 后磁盘操作量增加
- python - 如何遍历列表并更改其中的值?
- python - Python:如何以两个参数作为参数传递函数