PyTorch iter() runs indefinitely or throws RecursionError

Problem Description

I am running into a problem with the PyTorch DataLoader. Whenever I try to load the next batch via iter(), the call runs indefinitely. I also tried running it in Google Colab, where it returns a RecursionError instead. Here is the code that builds the loader and fetches a batch:

def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = PostTitleDataset(
        title = df["clean_title"].to_numpy(),
        label = df["6_way_label"].to_numpy(),
        tokenizer = tokenizer,
        max_len = max_len
    )
    
    return DataLoader(
        ds,
        batch_size = batch_size,
        num_workers = 0
    )

BATCH_SIZE = 16
MAX_LEN = 80

test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

# This function runs indefinitely or gives me a RecursionError in Google Colab
if __name__ == '__main__':
  data = next(iter(test_data_loader))
  data.keys()

This is the error message from Google Colab. In Jupyter Notebook it simply runs indefinitely with no error message:

RecursionError                            Traceback (most recent call last)
in ()
----> 1 data = next(iter(test_data_loader))
      2 data.keys()

5 frames
... last 1 frames repeated, from the frame below ...

in __getitem__(self, item)
     24         #batch = convert_to_batch(dataframe)
     25 
---> 26         title = self["clean_title"][item]
     27         label = self["6_way_label"][item]
     28 

RecursionError: maximum recursion depth exceeded

Does anyone know how to fix this and get the iter() call to successfully return a batch?

Just for completeness, below is my custom dataset class:

class PostTitleDataset(Dataset):
    
    def __init__(self, title, label, tokenizer, max_len):
        self.title = title,
        self.label = label,
        self. tokenizer = tokenizer,
        self.max_len = max_len
        
    def __len__(self):
        return len(self.title)
    
    def __getitem__(self, item):
        
        #batch = convert_to_batch(dataframe)
        
        title = self["clean_title"][item]
        label = self["6_way_label"][item]
        
        # Encode text input content
        encoding = self.tokenizer(
            title,
            padding=True,
            truncation=True,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        
        # Return first post title of batch as validation
        return {
            "clean title": title,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "Label": torch.tensor(label, dtype=torch.long)
        }

Tags: python, pytorch, iterator, pytorch-dataloader

Solution
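
The traceback points at the cause: inside __getitem__, the expressions self["clean_title"][item] and self["6_way_label"][item] index the dataset object itself, so each call to __getitem__ invokes __getitem__ again and the recursion never terminates (Colab reports the RecursionError; plain Jupyter just appears to hang). A second issue is the trailing commas in __init__, which wrap each attribute in a one-element tuple, so len(self.title) would always be 1. Below is a minimal sketch of a corrected dataset class, assuming a Hugging Face tokenizer and the same DataFrame columns as in the question; padding="max_length" with max_length=self.max_len is one way to make every sample the same length so the default collate function can stack them into a batch.

import torch
from torch.utils.data import Dataset


class PostTitleDataset(Dataset):

    def __init__(self, title, label, tokenizer, max_len):
        # No trailing commas here, otherwise each attribute becomes a 1-element tuple
        self.title = title
        self.label = label
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.title)

    def __getitem__(self, item):
        # Index the stored arrays, not the dataset object itself
        title = str(self.title[item])
        label = self.label[item]

        # Encode the title; pad/truncate to a fixed length so the default
        # collate_fn can stack the tensors of different titles into a batch
        encoding = self.tokenizer(
            title,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return {
            "clean_title": title,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "label": torch.tensor(label, dtype=torch.long),
        }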


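With the recursion removed, fetching a batch should return immediately. A quick check, assuming df_test, tokenizer, and create_data_loader are defined as in the question:

test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

if __name__ == '__main__':
    data = next(iter(test_data_loader))
    print(data.keys())               # dict_keys(['clean_title', 'input_ids', 'attention_mask', 'label'])
    print(data["input_ids"].shape)   # e.g. torch.Size([16, 80]) with BATCH_SIZE=16 and MAX_LEN=80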