首页 > 解决方案 > 如何从数据帧在 keras 流中提供一个热编码矢量数据帧

问题描述

我有一个 multi_class 问题而不是 multi_label 问题,并且我有一个如下的数据框。
数据框

我希望它在 flow_from_dataframe 中使用

train_generator=train_data_gen.flow_from_dataframe(train_df,directory='directory',
                                                      target_size=(img_shape,img_shape),
                                                      x_col="image_id",
                                                      y_col=['healthy','multiple_diseases','rust','scab'],
                                                      class_mode='categorical',
                                                      shuffle=False,
                                                       subset='training',
                                                      batch_size=batch_size)

我收到以下错误

TypeError: If class_mode="categorical", y_col="['healthy', 'multiple_diseases', 'rust', 'scab']" column values must be type string, list or tuple.

标签: pythontensorflowkerasscikit-learn

解决方案


使用class_mode = "raw"以便所有 4 个类都加载有二进制标签。

有关如何修改标签以及使用 class_mode 进行多类分类的各种方法的信息,我推荐这篇文章。


推荐阅读