首页 > 解决方案 > 在生成无限数据时,PyTorch 的 __len__ 应该是什么?

问题描述

假设我正在尝试使用 PyTorch 来学习方程式y = 2x,并且我想生成无限量的数据来训练我的模型。我应该提供一个__len__功能。下面是一个例子。在这种情况下应该是什么?如何指定每个时期的小批量迭代次数?我只是随意设置一个数字吗?

import numpy as np
from torch.utils.data import Dataset

class GenerateUnlimitedData(Dataset):
    def __init__(self):
        pass
    
    def __getitem__(self, index):
        x = np.random.randint(1,10)
        y = 2 * x
        return x, y
    
    def __len__(self):
        return 1000000 # This works but is not correct

标签: pythonpytorch

解决方案


您应该使用torch.utils.data.IterableDataset而不是torch.utils.data.Dataset. 在您的情况下,它将是:

import torch


class Dataset(torch.utils.data.IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            x = torch.randint(1, 10, (self.batch_size,))
            y = 2 * x
            yield x, y

您应该使用批处理(可能是大批处理),因为这会加快计算速度(pytorch 非常适合一次对多个样本进行 GPU 计算)。


推荐阅读