python - 在 PyTorch 中实现“无限循环”数据集和数据加载器
问题描述
我想实现一个无限循环 Dataset & DataLoader。这是我尝试过的:
class Infinite(Dataset):
def __len__(self):
return HPARAMS.batch_size
# return 1<<30 # This causes huge memory usage.
def __getitem__(self, idx):
"""Randomly generates one new example."""
return sample_func_to_be_parallelized()
infinite_loader = DataLoader(
dataset=Infinite(),
batch_size=HPARAMS.batch_size,
num_workers=16,
worker_init_fn=lambda worker_id: np.random.seed(worker_id),
)
while True:
for idx, data in enumerate(infinite_loader):
# forward + backward on "data"
如您所见,这里的主要挑战是__len()__
方法。如果我在那里放了一个足够大的数字,比如 1<<30,则症状是内存使用量将在训练循环的第一次迭代中跳转到 10+GB。一段时间后,可能是由于 OOM 导致工人死亡。
如果我在那里放一个小数字,比如 1 或 BATCH_SIZE,训练循环中的采样“数据”将定期复制。这不是我想要的,因为我希望在每次迭代时生成和训练新数据。
我猜过度内存使用的罪魁祸首是在堆栈中的某个地方,一堆东西被缓存了。随便看看 Python 方面的东西,我无法确定在哪里。
有人可以建议什么是实现我想要的最佳方式?(使用 DataLoader 的并行加载,同时保证加载的每个批次都是全新的。)
解决方案
这似乎在不定期复制数据的情况下工作:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
BATCH_SIZE = 2
class Infinite(Dataset):
def __len__(self):
return BATCH_SIZE
def __getitem__(self, idx):
return torch.randint(0, 10, (3,))
data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)
batch_count = 0
while True:
batch_count += 1
print(f'Batch {batch_count}:')
data = next(iter(data_loader))
print(data)
# forward + backward on "data"
if batch_count == 5:
break
结果:
Batch 1:
tensor([[4, 7, 7],
[0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
[2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
[8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
[2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
[2, 7, 5]])
所以我认为问题出在你的功能sample_func_to_be_parallelized()
上。
编辑:如果不是torch.randint(0, 10, (3,))
我使用np.random.randint(10, size=3)
in __getitem__
(作为 的示例sample_func_to_be_parallelized()
),那么数据确实在每批中重复。看到这个问题。
所以如果你在你的某个地方使用 numpy 的 RGN sample_func_to_be_parallelized()
,那么解决方法是使用
worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id)
并np.random.seed()
在每次调用data = next(iter(data_loader))
.
推荐阅读
- javascript - Ajax 请求多次提交输入
- kubernetes - 使用法兰绒网络插件无法从容器内访问互联网
- javascript - 对于给定的 4 个条目,2.9 是 14.5,得分为 20 是否正确
- c# - 如何创建具有多个参数的线程?
- c# - C#如何将自定义字段添加到getstream.io中的活动?
- java - Java堆分解总和不等于总堆容量
- r - 当我在 R Studio 中运行我的 R 时,我收到有关规范化路径的消息。我如何让它消失
- angularjs - 如何传递过滤器参数
服务门户 - vb.net - 将 BigInteger 向左移动后出现意外值
- c++ - 将结构数据类型传递给 C++ 中的命名管道