首页 > 技术文章 > DL踩坑:初尝过拟合

weiba180 2020-03-02 00:00 原文

初尝过拟合

猫狗大战数据集

   |-- train 
       |-- cat
           |-- 1.jpg
           |-- 2.jpg
           |-- ...
       |-- dog
           |-- 1.jpg
           |-- 2.jpg
           |-- ...
  • 这样就可以调用torchvision来快速生成训练集
all_data =  torchvision.datasets.ImageFolder(
    root='train',
    transform=train_transform
)
all_len = int(len(all_data))
train_len = int(0.8*all_len)
vaild_len = int(all_len-train_len)
train_data , vaild_data= torch.utils.data.random_split(all_data,[train_len,vaild_len])
train_set = torch.utils.data.DataLoader(
    train_data,
    batch_size=BTACH_SIZE,
    shuffle=True
)
vaild_set = torch.utils.data.DataLoader(
    test_data,
    batch_size=BTACH_SIZE,
    shuffle=False
)
  • 上面偷懒对整个train里的图片分成两个分类的文件夹,没有另开一个验证集文件夹,而是使用random_split来划分,本来是没什么问题的,但就是这样埋下了锅,后面乱搞了一下就GG了,导致模型严重过拟合还浑然不知。

模型的选择

  • 一开始直接简单粗暴来AlexNet,但是不知道什么问题,不仅跑的慢,训练了几个epoch后收敛得也慢(其实可能根本没有在收敛
  • 好吧那就先放放,换了一个简单的CNN,对就和上次训练MNIST那个那么简单,先训练几个epoch试下,开跑~嗯嗯嗯跑的快了许多,loss也在减少,验证集准确率也终于不是50上下了,虽然和训练集相比提高得有点慢慢
  • 然后我调大了epoch继续训练...

锅在哪里呢?

  • 训练到后面训练集准确率九十多,验证集也九十了,果断把测试机test跑一遍提交且沾沾自喜
  • 然后一看得分蒙了,和全选猫的得分差不多因为我真的提交过,很明显是严重的过拟合了
  • 锅出在哪呢,看了下代码原来是我一开始跑的时候保存了模型,然后后面跑的时候再加载继续训练。但由于多次运行main文件,而random_split是随机划分的,所以就等于我把整个训练数据都跑了个遍,再加上模型过于简单,没有加dropout等,过拟合了也不知道,还傻乎乎交了上去

更新

  • 又发现了一个问题,测试集的图片并非按照1234的顺序排的,而是按照1,10,100。。。的字典序排的,原因是imagfolder中对文件名采取了字典序排序,详看pytorch源码
def make_dataset(dir, class_to_idx, extensions):
    images = []
    # expanduser把path中包含的"~"和"~user"转换成用户目录
    # 主要还是在Linux之类的系统中使用,在不包含"~"和"~user"时
    # dir不变
    dir = os.path.expanduser(dir)
    # 排序后按顺序通过for循环dir路径下的所有文件名
    for target in sorted(os.listdir(dir)):
        # 将路径拼合
        d = os.path.join(dir, target)
        # 如果拼接后不是文件目录,则跳出这次循环
        if not os.path.isdir(d):
            continue
        # os.walk(d) 返回的fnames是当前d目录下所有的文件名
        # 注意:第一个for其实就只循环一次,返回的fnames 是一个数组
        for root, _, fnames in sorted(os.walk(d)):
            # 循环每一个文件名
            for fname in sorted(fnames):
                # 文件的后缀名是否符合给定
                if has_file_allowed_extension(fname, extensions):
                    # 组合路径
                    path = os.path.join(root, fname)
                    # 将组合后的路径和该文件位于哪一个序号的文件夹下的序号
                    # 组成元祖
                    item = (path, class_to_idx[target])
                    # 将其存入数组中
                    images.append(item)

    return images
  • 所以说调包也要熟悉才行

推荐阅读