pytorch - Pytorch - 自定义 DataLoader 永远运行
问题描述
class TripletImageLoader(torch.utils.data.Dataset):
def __init__(self):
self.data = [0]*10000000
def __getitem__(self, index):
pid = os.getpid() % WORKER_SIZE
# My code here only uses pid, doesnt use index
return torch.tensor(batch.data), torch.tensor(batch.label)
def __len__(self):
return len(self.data)
我需要我的数据加载器永远运行。现在它总是在达到 10000000 或任何最大整数大小后终止。我如何让它永远运行,我不关心'索引'我没有使用它。我只是在使用这个类的工人能力
解决方案
由于您需要对同一批次进行多次迭代训练,因此以下代码框架应该适合您。
def train(args, data_loader):
for idx, ex in enumerate(data_loader):
# iterate over each mini-batches
# add your code
def validate(args, data_loader):
with torch.no_grad():
for idx, ex in enumerate(data_loader):
# iterate over each mini-batches
# add your code
# args = dict() containing required parameters
for epoch in range(start_epoch, args.num_epochs):
# train_loader = data loader for the training data
train(args, train_loader)
您可以按如下方式使用数据加载器。
class ReaderDataset(Dataset):
def __init__(self, examples):
# examples = a list of examples
# add your code
def __len__(self):
# return total dataset size
def __getitem__(self, index):
# write your code to return each batch item
train_dataset = ReaderDataset(train_examples)
train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=args.data_workers,
collate_fn=batchify,
pin_memory=args.cuda,
drop_last=args.parallel
)
# batchify is a custom function to prepare the mini-batches
推荐阅读
- javascript - nodejs - 未显示 console.log 消息
- macos - 在 macOS Big Sur 上安装 Qt4
- python - 如何将数据框的不同分类数据汇总到不同的列中
- reactjs - MUI Select 组件中的自定义文本
- pyspark - spark.sql() 中不等于什么
- turtle-graphics - 如何将正常图像转换为色差图像?
- matlab - 希尔伯特变换和 MATLAB 中的时间频谱
- python - Visual Studio 2019 Python 发布到文件系统失败
- android - 如何在android中调用ffmpeg视频命令
- c++ - 关于在图中使用顶点作为索引 c++ 为什么我们浪费空间