首页 > 解决方案 > 使用 flow_from_directory() 时如何知道标签的顺序?

问题描述

我正在使用 softmax 进行多类分类。我的最终输出给了我三个概率,每个概率对应一个类,我怎么知道每个概率指的是哪个类?我正在通过图像数据生成器训练我的模型。

train_generator = train_datagen.flow_from_directory(
    '/gdrive/MyDrive/shot/training',
    target_size=(640, 360),
    batch_size=32,
    class_mode='categorical')

/gdrive/MyDrive/shot/training
在同一个项目中,我通过 test_datagen 提供一些图像进行预测,我如何知道 model.predict() 正在处理哪个图像?

t_gen = test_datagen.flow_from_directory(
    '/gdrive/MyDrive/shot/testing',
    target_size=(640, 360),
    batch_size=32,
    class_mode='categorical')

classes = model.predict(t_gen,batch_size=32)
print(classes)

我得到一个充满概率的表格作为答案,但我不知道预测了哪个图像以及每个概率对应于哪个标签。

标签: pythontensorflow

解决方案


由于您是从图像数据生成器训练的,因此您可以按如下方式使用您的课程

labels = list(train_datagen.class_indices.keys())
pred = model.predict(input_data)
classes = labels[pred.argmax()]

print(classes)

# or
labels = ['badshot', 'goodshot', 'noshot']
pred = model.predict(input_data)
classes = labels[pred.argmax()]

推荐阅读