python - 仅在 K-folds 交叉验证中增加训练集
问题描述
我正在尝试为不平衡的数据集(0 类 = 4000 个图像,1 类 = 大约 250 个图像)创建一个二进制 CNN 分类器,我想对其执行 5 倍交叉验证。目前,我正在将训练集加载到 ImageLoader 中,该 ImageLoader 应用我的转换/增强(?)并将其加载到 DataLoader 中。但是,这会导致我的训练拆分和验证拆分都包含增强数据。
我最初应用离线转换(离线增强?)来平衡我的数据集,但是从这个线程(https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate- split),似乎只增加训练集是理想的。我还希望在仅增强训练数据上训练我的模型,然后在 5 折交叉验证中在非增强数据上对其进行验证
我的数据被组织为根/标签/图像,其中有 2 个标签文件夹(0 和 1)和图像分类到各自的标签中。
到目前为止我的代码
total_set = datasets.ImageFolder(ROOT, transform = data_transforms['my_transforms'])
//Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)
for train_idx, valid_idx in splits.split(total_set):
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)
model.train()
//Model train/eval works but may be overpredict
我确定我在这段代码中做的不是最优或错误的,但我似乎找不到任何关于专门增加交叉验证中的训练拆分的文档!
任何帮助,将不胜感激!
解决方案
一种方法是实现一个包装数据集类,该类将转换应用于 ImageFolder 数据集的输出。例如
class WrapperDataset:
def __init__(self, dataset, transform=None, target_transform=None):
self.dataset = dataset
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
image, label = self.dataset[index]
if self.transform is not None:
image = self.transform(image)
if self.target_transform is not None:
label = self.target_transform(label)
return image, label
def __len__(self):
return len(self.dataset)
然后,您可以通过使用不同的转换包装更大的数据集,在您的代码中使用它。
total_set = datasets.ImageFolder(ROOT)
# Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)
for train_idx, valid_idx in splits.split(total_set):
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = torch.utils.data.DataLoader(
WrapperDataset(total_set, transform=data_transforms['train_transforms']),
batch_size=32, sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(
WrapperDataset(total_set, transform=data_transforms['valid_transforms']),
batch_size=32, sampler=valid_sampler)
# train/validate now
我没有测试过这段代码,因为我没有你的完整代码/模型,但概念应该很清楚。
推荐阅读
- android - 屏幕锁定时无法运行 ble 扫描仪
- javascript - 如何使用 preact 创建页面转换
- java - 为什么 java LONG_MAX 与 vue axios 数据不一样??如何解决?
- java - 从字节数组解码 Avro 文件
- python - Python:如果值不是预期的,那么发出警报的正确方法是什么?
- python - 无法在 Pyqt5 应用程序中导出 Folium 绘制多边形上的坐标
- ckan - CKAN 登录重定向到缺少 root_path 的 URL
- php - php / 正则表达式将多个段落转换为一个带有换行符的段落
- join - 在 Cakephp4 中加入查询只返回一个表数据
- javascript - 随机选项如何捕捉通讯号码?