首页 > 解决方案 > 使用 IterableDataset 加载巨大的自定义数据集

问题描述

我有一个巨大的数据集,其特征(input_id、input_mask、segment_id、label_id)以 64 个批次保存在一个 pickle 文件中。我阅读了这个文件,创建了一个 TensorDataset 并传递给数据加载器进行训练。由于特征文件太大而无法创建完整的 TensorDataset,我想将 TensorDataset 转换为 IterableDataset,以便一次从特征文件中检索一批样本并将其传递给数据加载器。但是在训练时,我收到以下错误: TypeError: iter() returned non-iterator of type 'TensorDataset'

以下是我编写的自定义数据集类:

class MyDataset(IterableDataset):

    def __init__(self,args):
        self.args=args
       
    def get_features(self,filename):
        with open(filename, "rb") as f:
            while True:
                try:
                    yield pickle.load(f)
                except EOFError:
                    break  
                    
    def process(self,args):
        if args.cached_features_file:
            cached_features_file = args.cached_features_file

        if os.path.exists(cached_features_file):
            features=self.get_features(cached_features_file)

        feat = next (features)
        li=list(feat)
        all_input_ids=torch.tensor([f.input_ids for f in li ], dtype=torch.long)
        all_input_mask= torch.tensor([f.input_mask for f in li ], dtype=torch.long)
        all_segment_ids= torch.tensor([f.segment_ids for f in li], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in li ], dtype=torch.long)
        
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset
      
    def __iter__(self):
        dataset=self.process(self.args)       
        return dataset

我像这样使用它:

train_dataset=MyDataset(args)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)

我知道 TensorDataset 是需要索引的地图样式,而 IterableDataset 是可迭代样式,这是错误的原因。即使我返回特征张量的列表/元组而不是 TensorDataset,我也会收到类似的错误。有人可以告诉我如何使用 IterableDataset 以正确的方式加载批处理数据集吗?

标签: pythonpytorch

解决方案


我通过以不同的方式保存数据集解决了这个问题。我将这些特征作为字典对象保存在一个pickle文件中,然后一次读取一个,然后传递给数据加载器进行处理。批处理由数据加载器自动完成。这是自定义类现在的样子:

class MyDataset(IterableDataset):

    def __init__(self,filename):
     
        self.filename=filename
        super().__init__()
                    
    def process(self,filename):
        with open(filename, "rb") as f:
            while True:
                try:
                    yield pickle.load(f)
                except EOFError:
                    break

    def __iter__(self):
        dataset=self.process(self.filename)          
        return dataset

推荐阅读