python - pytorch 数据集对象在 for 循环中使用时如何知道它是否已经结束?
问题描述
我正在编写一个自定义 pytorch 数据集。在__init__
数据集对象中加载一个包含特定数据的文件。但在我的程序中,我只希望访问部分数据(如果有帮助,可以实现训练/有效切割)。最初我认为这种行为是通过覆盖来控制的__len__
,但事实证明修改__len__
没有帮助。一个简单的例子如下:
from torch.utils.data import Dataset, DataLoader
import torch
class NewDS(Dataset):
def __init__(self):
self.data = torch.randn(10,2) # suppose there are 10 items in the data file
def __len__(self):
return len(self.data)-5 # But I only want to access the first 5 items
def __getitem__(self, index):
return self.data[index]
ds = NewDS()
for i, x in enumerate(ds):
print(i)
输出是 0 到 9,而期望的行为是 0 到 4。
当在这样的 for 循环中使用时,这个数据集对象如何知道枚举已经结束?也欢迎任何其他实现类似效果的方法。
解决方案
您正在使用Dataset
类创建自定义数据加载器,同时使用 for 循环枚举它。这不是它的工作方式。对于枚举,您必须通过Dataset
toDataLoader
类。你的代码会像这样很好地工作,
from torch.utils.data import Dataset, DataLoader
import torch
class NewDS(Dataset):
def __init__(self):
self.data = torch.randn(10,2) # suppose there are 10 items in the data file
def __len__(self):
return len(self.data)-5 # But I only want to access the first 5 items
def __getitem__(self, index):
return self.data[index]
ds = NewDS()
for i, x in range(len(ds)): #if you do dont want to use DataLoader, then dont use enumerate
print(i, ds[i])
#output
tensor([-0.2351, 1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])
dl = DataLoader(ds, batch_size=1) # pass the ds object to DataLoader
for i, x in enumerate(dl): # now you can use enumarate
print(i, x)
#output
tensor([-0.2351, 1.3037])
tensor([ 0.4032, -0.2739])
tensor([-0.5687, -0.7300])
tensor([0.5418, 0.8572])
tensor([ 1.9973, -0.2939])
更多细节可以在这个官方pytorch 教程中阅读。
推荐阅读
- python - 使用 Python 将计算的滤波器组存储到频谱图图像中
- c# - 将 C# 中的流转换为 C++ 中的 ofstream
- javascript - JQuery 弹出窗口未显示
- csrf - 了解 JHipster 中的 HTTP Session 安全机制
- javascript - 如何在javascript中将两个不同的对象数组组合成一个新的对象数组?
- r - 在ggplotly上添加第二个Y轴
- gnuplot - 多图中的行标题
- android - 如何获得 Android 虚拟设备的实际调整大小?
- python - 将部分 ascii 文本文件上传到 PostgreSQL 表中
- python - 负尺寸大小由 1 减去 2 导致的 'max_pooling2d_3/MaxPool' (op: 'MaxPool') 输入形状:[?,1,148,32]