python - 从多个不同大小的数据集加载 PyTorch 数据
问题描述
我有多个数据集,每个数据集都有不同数量的图像(和不同的图像尺寸)。在训练循环中,我想从所有数据集中随机加载一批图像,但每个批次只包含来自单个数据集的图像。例如,我有数据集 A、B、C、D,每个都有图像 01.jpg、02.jpg、... n.jpg(其中 n 取决于数据集),假设批量大小为 3。在第一个加载批次,例如,我可能会得到图像[B/02.jpg,B/06.jpg,B/12.jpg],在下一批[D/01.jpg,D/05.jpg,D/12 .jpg]等
到目前为止,我已经考虑了以下几点:
- 对每个数据集使用不同的DataLoader,例如dataloaderA、dataloaderB等,然后在每个训练循环中随机选择一个dataloader并从中获取一批。但是,这将需要一个 for 循环,并且对于大量数据集,它会非常慢,因为它不能在工作人员之间拆分以并行执行。
- 将单个 DataLoader 与来自所有数据集的所有图像一起使用,但使用自定义 collate_fn 将仅使用来自同一数据集的图像创建批处理。(我不确定该怎么做。)
- 我查看了 ConcatDataset 类,但从它的源代码来看,如果我使用它并尝试获取一个新批次,其中的图像将从我不想要的不同数据集中混合。
最好的方法是什么?谢谢!
解决方案
您可以使用ConcatDataset
,并提供一个batch_sampler
to DataLoader
。
concat_dataset = ConcatDataset((dataset1, dataset2))
ConcatDataset.comulative_sizes
将为您提供您拥有的每个数据集之间的界限:
ds_indices = concat_dataset.cumulative_sizes
现在,您可以使用ds_indices
创建批处理采样器。请参阅来源以BatchSampler
供参考。您的批处理采样器只需要返回一个包含 N 个随机索引的列表,这些索引将尊重ds_indices
边界。这将保证您的批次将具有来自同一数据集的元素。
推荐阅读
- robotframework - 我有一个用于输入文本框的特定 xpath,如何在使用机器人框架时访问这个 xpath?//*[@id="p_0"]/table/tbody/tr[2]/td[2]/input
- python - 如何从 MATLAB R2019a 调用 Python?
- mongodb - mongoimport/mongoexport 不保留哪些 MongoDB 类型?
- javascript - EPERM:不允许操作,取消链接 'C:\Users\**\node_modules\.node-sass.DELETE\vendor\win32-x64-57\binding.node'
- python - 如何修复 ubuntu 终端中的 'ProxyError('Cannot connect to proxy.', error('Tunnel connection failed: 407 Proxy Authentication Required',)'?
- python - 如何在脚本中修改 if-else 逻辑以检索 T-14 天的数据
- php - Laravel - 为代码获取用户
- c# - WPF切换按钮悬停动画问题
- google-sheets - 带有 2 条规则的 Google 表格中的条件格式
- ansible - 如何仅将 Ansible 的播放回顾保存到文件中?