首页 > 解决方案 > 使用文件夹结构在 Keras 中进行预测时如何获得正确的标签?

问题描述

由于我的数据集是一组图像,因此我使用文件夹结构来组织数据:

/train
    /class1
        img.jpg
        img.jpg
        ...
    /class2
        ...
/validation
    /class1
        ...
    /class2
        ...

只有两个类,所以我使用了二进制类模式,如下所示:

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

现在,当使用该predict()函数时,输出显然是一个介于 0 和 1 之间的数值。但是,我不知道哪个标签属于哪个值(0 或 1)。如何获得真正的标签(class1 或 class2)?

标签: pythontensorflowkerasdeep-learninggenerator

解决方案


要获取类名,只需调用:

generator.classes

推荐阅读