首页 > 解决方案 > Pytorch DataLoader 多数据源

问题描述

我正在尝试使用 Pytorch 数据加载器来定义我自己的数据集,但我不确定如何加载多个数据源:

我当前的代码:

class MultipleSourceDataSet(Dataset):
    def __init__ (self, json_file, root_dir, transform = None):
        with open(root_dir + 'block0.json') as f:
            self.result = torch.Tensor(json.load(f))

    self.root_dir = root_dir
    self.transform = transform

    def __len__(self):
        return len(self.result[0])

    def __getitem__ (self):
        None

数据源为50块下root_dir = ~/Documents/blocks/

我将它们拆分并避免之前直接组合它们,因为这是一个非常大的数据集。

如何将它们加载到单个数据加载器中?

标签: python-3.ximage-processingmachine-learningdeep-learningpytorch

解决方案


因为DataLoader你需要一个Dataset,你的问题是你有多个'json'文件,你只知道如何分别创建Dataset一个'json'
在这种情况下,您可以做的是使用ConcatDataset包含'json'您创建的所有单数据集:

import os
import torch.utils.data as data

class SingeJsonDataset(data.Dataset):
    # implement a single json dataset here...

list_of_datasets = []
for j in os.path.listdir(root_dir):
    if not j.endswith('.json'):
        continue  # skip non-json files
    list_of_datasets.append(SingeJsonDataset(json_file=j, root_dir=root_dir, transform=None))
# once all single json datasets are created you can concat them into a single one:
multiple_json_dataset = data.ConcatDataset(list_of_datasets)

现在您可以将连接的数据集输入到data.DataLoader.


推荐阅读