首页 > 解决方案 > 如何从 pytorch DataLoader 获取特定样本?

问题描述

在 Pytorch 中,有没有办法使用类加载特定的单个样本torch.utils.data.DataLoader?我想用它做一些测试。

教程使用

trainloader = torch.utils.data.DataLoader(...)
images, labels = next(iter(trainloader))

获取随机批次的样本。有没有办法,使用DataLoader,得到一个特定的样本?

干杯

标签: pytorch

解决方案


  • 关闭shuffle输入DataLoader
  • 用于batch_size计算您要查找的所需样品所属的批次
  • 迭代到所需的批次

代码

import torch 
import numpy as np
import itertools

X= np.arange(100)
batch_size = 2

dataloader = torch.utils.data.DataLoader(X, batch_size=batch_size, shuffle=False)
sample_at = 5
k = int(np.floor(sample_at/batch_size))

my_sample = next(itertools.islice(dataloader, k, None))
print (my_sample)

输出:

tensor([4, 5])

推荐阅读