python - 在 keras ImageDataGenerator.flow() 中使用多输出标签并使用 model.fit_generator()
问题描述
我有一个单输入多输出神经网络模型,其最后一层是
out1 = Dense(168, activation = 'softmax')(dense)
out2 = Dense(11, activation = 'softmax')(dense)
out3 = Dense(7, activation = 'softmax')(dense)
model = Model(inputs=inputs, outputs=[out1,out2,out3])
每个图像的 Y 标签如下
train
>>
image_id class_1 class_2 class_3
0 Train_0 15 9 5
1 Train_1 159 0 0
...
...
...
453651 Train_453651 0 15 34
453652 Train_453652 18 0 7
编辑:-
train.iloc[:,1:4].nunique()
>>
class_1 168
class_2 11
class_3 7
dtype: int64
所以看看这些不同范围的类,我应该使用categorical_crossentropy
orsparse_categorical_crossentropy
吗?以及我应该如何Y_labels
为下面给出的代码使用流入?
imgs_arr = df.iloc[:,1:].values.reshape(df.shape[0],137,236,1)
# 32332 columns representing pixels of 137*236 and single channel images.
# converting it to (samples,w,h,c) format
Y = train.iloc[:,1:].values #need help from here
image_data_gen = ImageDataGenerator(validation_split=0.25)
train_gen = image_data_gen.flow(x=imgs_arr, y=Y, batch_size=32,subset='training')
valid_gen = image_data_gen.flow(x=imgs_arr,y=Y,subset='validation')
这是通过Y
或使用Y=[y1,y2,y3]
where的正确方法吗
y1=train.iloc[:,1].values
y2=train.iloc[:,2].values
y3=train.iloc[:,3].values
解决方案
哎哟....
根据您给出的消息flow
,您将需要一个输出。因此,您需要在模型内部进行分离。(Keras 没有遵循自己的标准)
这意味着类似:
Y = train.iloc[:,1:].values #shape = (50210, 3)
使用单个输出,例如:
out = Dense(168+11+7, activation='linear')(dense)
还有一个处理分离的损失函数:
def custom_loss(y_true, y_pred):
true1 = y_true[:,0:1]
true2 = y_true[:,1:2]
true3 = y_true[:,2:3]
out1 = y_pred[:,0:168]
out2 = y_pred[:,168:168+11]
out3 = y_pred[:,168+11:]
out1 = K.softmax(out1, axis=-1)
out2 = K.softmax(out2, axis=-1)
out3 = K.softmax(out3, axis=-1)
loss1 = K.sparse_categorical_crossentropy(true1, out1, from_logits=False, axis=-1)
loss2 = K.sparse_categorical_crossentropy(true2, out2, from_logits=False, axis=-1)
loss3 = K.sparse_categorical_crossentropy(true3, out3, from_logits=False, axis=-1)
return loss1+loss2+loss3
用 编译模型loss=custom_loss
。
那么flow
当你这样做时应该停止抱怨flow
。
只需确保 X 和 Y 的顺序完全相同:正确imgs_arr[i]
对应Y[i]
。
推荐阅读
- node.js - 在nodejs express中上传多个文件,将目标名称作为新的对象ID
- javascript - Javascript 滚动顶部
- reactjs - 将自定义“全部”(总计)行添加到可像任何其他行一样选择的 ag 网格的顶部
- r - R闪亮的渲染表 - 如何更改列宽以包装特定列的文本?
- python - 是否可以从其他笔记本复制和粘贴 Jupyter Notebook 单元格并让新笔记本能够调用声明的变量?
- flutter - Mi android manifest.xml no reconoce los permisos
- kubernetes - 如何强制所有 kubernetes 服务(代理、kublet、apiserver...、容器)将日志写入 /var/logs
- marklogic - Forest Meters 启动时出错:XDMP-BADSTARTUPTOKEN:Forest Meters 的启动令牌错误
- apache - Xampp Apache 无法在 Win 10 上启动
- java - 自 Java 9 以来更改了 Swing TitledBorder 的外观