首页 > 解决方案 > 从多个不同大小的数据集加载 PyTorch 数据

问题描述

我有多个数据集,每个数据集都有不同数量的图像(和不同的图像尺寸)。在训练循环中,我想从所有数据集中随机加载一批图像,但每个批次只包含来自单个数据集的图像。例如,我有数据集 A、B、C、D,每个都有图像 01.jpg、02.jpg、... n.jpg(其中 n 取决于数据集),假设批量大小为 3。在第一个加载批次,例如,我可能会得到图像[B/02.jpg,B/06.jpg,B/12.jpg],在下一批[D/01.jpg,D/05.jpg,D/12 .jpg]等

到目前为止,我已经考虑了以下几点:

  1. 对每个数据集使用不同的DataLoader,例如dataloaderA、dataloaderB等,然后在每个训练循环中随机选择一个dataloader并从中获取一批。但是,这将需要一个 for 循环,并且对于大量数据集,它会非常慢,因为它不能在工作人员之间拆分以并行执行。
  2. 将单个 DataLoader 与来自所有数据集的所有图像一起使用,但使用自定义 collat​​e_fn 将仅使用来自同一数据集的图像创建批处理。(我不确定该怎么做。)
  3. 我查看了 ConcatDataset 类,但从它的源代码来看,如果我使用它并尝试获取一个新批次,其中的图像将从我不想要的不同数据集中混合。

最好的方法是什么?谢谢!

标签: pythonpytorch

解决方案


您可以使用ConcatDataset,并提供一个batch_samplerto DataLoader

concat_dataset = ConcatDataset((dataset1, dataset2))

ConcatDataset.comulative_sizes将为您提供您拥有的每个数据集之间的界限:

ds_indices = concat_dataset.cumulative_sizes

现在,您可以使用ds_indices创建批处理采样器。请参阅来源以BatchSampler供参考。您的批处理采样器只需要返回一个包含 N 个随机索引的列表,这些索引将尊重ds_indices边界。这将保证您的批次将具有来自同一数据集的元素。


推荐阅读