首页 > 解决方案 > pytorch 数据集对象在 for 循环中使用时如何知道它是否已经结束?

问题描述

我正在编写一个自定义 pytorch 数据集。在__init__数据集对象中加载一个包含特定数据的文件。但在我的程序中,我只希望访问部分数据(如果有帮助,可以实现训练/有效切割)。最初我认为这种行为是通过覆盖来控制的__len__,但事实证明修改__len__没有帮助。一个简单的例子如下:

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)

输出是 0 到 9,而期望的行为是 0 到 4。

当在这样的 for 循环中使用时,这个数据集对象如何知道枚举已经结束?也欢迎任何其他实现类似效果的方法。

标签: pythonfor-loopiteratorpytorch

解决方案


您正在使用Dataset类创建自定义数据加载器,同时使用 for 循环枚举它。这不是它的工作方式。对于枚举,您必须通过DatasettoDataLoader类。你的代码会像这样很好地工作,

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10,2) # suppose there are 10 items in the data file
    
    def __len__(self):
        return len(self.data)-5 # But I only want to access the first 5 items
        
    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in range(len(ds)): #if you do dont want to use DataLoader, then dont use enumerate
    print(i, ds[i])
#output 
tensor([-0.2351,  1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])

dl = DataLoader(ds, batch_size=1) # pass the ds object to DataLoader 

for i, x in enumerate(dl): # now you can use enumarate
    print(i, x)
#output
tensor([-0.2351,  1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])

更多细节可以在这个官方pytorch 教程中阅读。


推荐阅读