首页 > 解决方案 > model.predict 类与数据集类不匹配

问题描述

我正在使用 Keras 编写一个 CNN 分类器,它应该将一组 40k 多张路标图片分类为 43 个类别中的一个。一切都很好,直到我试图找出模型在对看不见的数据进行分类时犯了哪些错误。输出文件中的类似乎与数据集中的类不匹配,我不知道如何确定哪个类是哪个。这个问题在问题的最后得到了更好的解释。

批处理大小为 64。输出文件非常大,但其结构如下:

[[3.81430182e-05 3.55855487e-02 3.77756208e-02 ... 3.93179851e-03 4.57952236e-04 1.19631949e-07]
[2.46175125e-09 8.71188703e-08 9.04489157e-12 ... 7.63094476e-08 2.24849509e-06 9.93708588e-13]
...
[1.31991830e-13 1.99924495e-12 7.65954244e-10 ... 1.51650678e-13 1.77550303e-14 9.25261628e-16]]
-
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
...
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

这是一个批次的输出,总共有 198 个这样的批次。首先有 64 行,每行 43 个值代表神经网络的输出。然后有 64 行,每行有 43 个值,表示哪个类是正确的分类。

在测试集中,类由文件夹结构表示,如下所示:

Test_New/0
    00245.png
    00252.png
    00403.png
   ...
Test_New/1
    00001.png
    00024.png
    00076.png
   ...
...
Test_New/42
    00315.png
    00507.png
    00755.png
    ...

问题是,文件中的类与输出文件中的类不匹配!换句话说,我希望在输出文件中这样:

[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

这意味着该特定图像的正确分类是第三类,因为 1 在第三个位置。但这种情况并非如此。我怎么知道?因为我知道在代表第三类的“Test_New/2”文件夹中正好有 750 个文件,但是当我使用 notepad++ 中的 find 函数查找

[0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]

行,它返回一个数字 660。这意味着文件中有 660 个该行的实例,这意味着它不能代表第三类。事实上,它代表第 11 类,因为它是唯一一个包含这么多文件的类。如果所有文件夹都有不同数量的文件,这将不是问题,但不幸的是其中一些共享相同数量的文件。

我的问题是为什么输出文件中的输出类被打乱了,我该如何解决这个问题?我怎么知道哪个班级是哪个班级?如果您不知道,您是否知道是否有其他方法可以知道哪些图像被错误分类?请帮忙,我在过去 3 个小时左右一直在拔头发。不好意思,代码这么多,就是不知道哪里出错了。谢谢!

标签: pythontensorflowkerasconv-neural-network

解决方案


在您的测试和验证生成器中设置 shuffle-False。在 model.fit 中不要指定 steps_per_epoch 或 validation_steps 让 model.fit 在内部确定这些值。现在您必须记住的一件事是,像 flow_from_directory 这样的 python 函数按字母数字顺序处理文件名。因此,例如,如果您在标有 1.jpg, 2.jpg ----9.jpg, 10,jpg 的目录中有文件 ---- 处理文件的顺序是 1.jpg, 10.jpg, 11 .jpg-----19.jpg,2.jpg ---。因此,如果您希望订单是严格的数字,则不是。下面是一个函数的代码,它将检测错误分类的测试文件并打印出所有错误分类的测试文件的文件名、真实类、预测类和预测概率。

def print_info( test_dir, test_gen, preds, print_code ):
    # test_dir is the full path to the directory containing the test images
    # test_gen is the name of your test generator
    # preds are the prediction from preds=model.predict
    # print code is an integer specifying the maximum number of error files you want to print out
    class_dict=test_gen.class_indices
    labels= test_gen.labels
    file_names= test_gen.filenames 
    error_list=[]
    true_class=[]
    pred_class=[]
    prob_list=[]
    new_dict={}
    error_indices=[]
    y_pred=[]
    for key,value in class_dict.items():
        new_dict[value]=key             # dictionary {integer of class number: string of class name}
    classes=list(new_dict.values())     # list of string of class names
    errors=0    
    for i, p in enumerate(preds):
        pred_index=np.argmax(p)
        true_index=labels[i]  # labels are integer values
        if pred_index != true_index: # a misclassification has occurred
            error_list.append(file_names[i])
            true_class.append(new_dict[true_index])
            pred_class.append(new_dict[pred_index])
            prob_list.append(p[pred_index])
            error_indices.append(true_index)            
            errors=errors + 1
        y_pred.append(pred_index)   
    if print_code !=0:
        if errors>0:
            if print_code>errors:
                r=errors
            else:
                r=print_code           
            msg='{0:^28s}{1:^28s}{2:^28s}{3:^16s}'.format('Filename', 'Predicted Class' , 'True Class', 'Probability')
            print(msg)
            for i in range(r):
                msg='{0:^28s}{1:^28s}{2:^28s}{3:4s}{4:^6.4f}'.format(error_list[i], pred_class[i],true_class[i], ' ', prob_list[i])
                print(msg)
                            
        else:
            msg='With accuracy of 100 % there are no errors to print'
            print(msg)

推荐阅读