python - 具有数据生成器的多输入 keras 神经网络模型
问题描述
我有两张图片。一只用于左眼,一只用于右眼。我想在神经网络中一次喂它们。
from sklearn.model_selection import train_test_split
XL_train, XL_val, yL_train, yL_val = train_test_split(XL, y, test_size=0.33, random_state=42)
XR_train, XR_val, yR_train, yR_val = train_test_split(XR, y, test_size=0.33, random_state=42)
XL 包含左眼 (3500, 224, 224, 3)
图像 XR 包含右眼图像(3500, 224, 224, 3)
我已经创建了数据生成器并将我的图像转换如下
XR_generator = train_datagen.flow(XR_train, yR_train, batch_size=BATCH_SIZE)
XL_generator = train_datagen.flow(XL_train, yL_train, batch_size=BATCH_SIZE)
vR_generator = val_datagen.flow(XR_val, yR_val, batch_size=BATCH_SIZE)
vL_generator = val_datagen.flow(XL_val, yL_val, batch_size=BATCH_SIZE)
使用 resnet 作为模型
import keras
left_input=Input(shape=XL.shape[1::])
right_input=Input(shape=XR.shape[1::])
left_model = ResNet50(include_top=False,input_tensor=left_input)
for layer in left_model.layers:
layer.name = layer.name + '_left'
layer.trainable = True
right_model = ResNet50(include_top=False,input_tensor=right_input)
for layer in right_model.layers:
layer.name = layer.name + '_right'
layer.trainable = True
使用 resnet 附加的层
x = keras.layers.concatenate([left_model.output, right_model.output])
x= keras.layers.Flatten()(x)
x = keras.layers.Dropout(0.2)(x)
x = keras.layers.Dense(512)(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Activation('relu')(x)
x = keras.layers.Dense(512)(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Activation('relu')(x)
x = keras.layers.Dropout(0.2)(x)
out = keras.layers.Dense(8, activation='sigmoid')(x)
model = keras.models.Model(inputs=[left_input, right_input], outputs=out)
编译模型并拟合模型如下
history=model.fit_generator(generator=[XL_generator,XR_generator],
steps_per_epoch=steps_train,
validation_data=[vL_generator,vR_generator],
validation_steps=steps_valid,
epochs=10,
)
我收到以下错误。
AttributeError:“NumpyArrayIterator”对象没有属性“ndim”
更新
def multi_train_gen(gen,XR_train,XL_train,yR_train,yL_train):
XR_generator = train_datagen.flow(XR_train, yR_train, batch_size=BATCH_SIZE)
XL_generator = train_datagen.flow(XL_train, yL_train, batch_size=BATCH_SIZE)
while True:
X1i = XR_generator.next()
X2i = XL_generator.next()
yield [X1i[0], X2i[0]], X2i[1]
def multi_val_gen(gen,XR_val,XL_val,yR_val,yL_val):
vR_generator = val_datagen.flow(XR_val, yR_val, batch_size=BATCH_SIZE)
vL_generator = val_datagen.flow(XL_val, yL_val, batch_size=BATCH_SIZE)
while True:
X1i = vR_generator.next()
X2i = vL_generator.next()
yield [X1i[0], X2i[0]], X2i[1]
训练生成器
train_gen=multi_train_gen(train_datagen,XR_train,XL_train,yR_train,yL_train)
验证生成器
val_gen=multi_val_gen(val_datagen,XR_val,XL_val,yR_val,yL_val)
但问题不在于我无法访问课程。我希望我可以使用 scikit learn 的分类报告
解决方案
推荐阅读
- python - 查找具有不同标题的行
- user-interface - 一个非常简单的 JavaFX wait()、notify() 示例
- r - 在 R 中争吵数据帧,可能使用 dcast
- sapui5 - 导航到不同的 (JS) 视图 - 无法访问路由器
- javascript - Angular Material 表没有排序、分页和过滤(使用 Angular、Nodejs 和 Mysql)
- c++ - 从 x86 更改为 x64 后的链接器错误
- machine-learning - pytorch 中使用 torch.no_grad() 的范围
- node.js - node.js Osmosis 编译 js 后得到结果
- flutter - Flutter,多个浮动操作按钮(缺口)
- android - 确保文件不存在 friestore