python - 在生成无限数据时,PyTorch 的 __len__ 应该是什么?
问题描述
假设我正在尝试使用 PyTorch 来学习方程式y = 2x
,并且我想生成无限量的数据来训练我的模型。我应该提供一个__len__
功能。下面是一个例子。在这种情况下应该是什么?如何指定每个时期的小批量迭代次数?我只是随意设置一个数字吗?
import numpy as np
from torch.utils.data import Dataset
class GenerateUnlimitedData(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
x = np.random.randint(1,10)
y = 2 * x
return x, y
def __len__(self):
return 1000000 # This works but is not correct
解决方案
您应该使用torch.utils.data.IterableDataset
而不是torch.utils.data.Dataset
. 在您的情况下,它将是:
import torch
class Dataset(torch.utils.data.IterableDataset):
def __init__(self, batch_size):
super().__init__()
self.batch_size = batch_size
def __iter__(self):
while True:
x = torch.randint(1, 10, (self.batch_size,))
y = 2 * x
yield x, y
您应该使用批处理(可能是大批处理),因为这会加快计算速度(pytorch 非常适合一次对多个样本进行 GPU 计算)。
推荐阅读
- css - 输入组边框在中间加倍
- android - Google Maps GeoJSON Utility 可以使用通过 geojson.io 设置的属性吗?
- reactjs - React.js,在 DOM 中迭代的正确方法
- python - 用下标数据刮表
- azure - angular 6 VSTS azure Build ng build configuration=production not working
- python - Python - 对比 2 列表
- bash - zsh 以编程方式填充命令内容
- java - Mongodb java驱动错误
- gradle - 从存储库下载并执行 jar 作为 Gradle 构建的第一步
- python - 两个不同形状的张量相减的结果是什么意思?