python - Creating a dataset with PyTorch DataLoader
Question
I have a folder with subfolders (classes), and each subfolder contains images.
data
|_ classe1
|    |_ image1
|    |_ image2
|_ classe2
     |_ ...
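My goal is to create a dataset (training + test sets) to train my model with a PyTorch ResNet. I get an error that I don't know how to fix, because I don't understand the DataLoader structure well. Here is what I tried: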
dataset = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['data']}
batch_size = 32
validation_split = .3
shuffle_dataset = True
random_seed= 42
# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)
dataloaders_dict = {'train': train_loader, 'val': validation_loader}
But when I try to run my model, I get this error:
Epoch 0/99
----------
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-79-8c30eb5e6a01> in <module>()
3
4 # Train and evaluate
----> 5 model_ft, hist = train_model(model_ft, dataloaders_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=False)
4 frames
<ipython-input-56-9421c2d39473> in train_model(model, dataloaders, criterion, optimizer, num_epochs, is_inception)
22
23 # Iterate over data.
---> 24 for inputs, labels in dataloaders[phase]:
25 inputs = inputs.to(device)
26 labels = labels.to(device)
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
401 def _next_data(self):
402 index = self._next_index() # may raise StopIteration
--> 403 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
404 if self._pin_memory:
405 data = _utils.pin_memory.pin_memory(data)
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
KeyError: 0
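Any suggestions? Can you spot any mistakes?
Solution
The problem most likely comes from your first line: your dataset is actually a dictionary containing a single element (a PyTorch dataset), not a dataset itself. This would be better: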
x = 'data'
dataset = datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
I assume data_transforms['data'] is a transform of the expected type.
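The KeyError is probably raised when PyTorch tries to fetch a sample from the "dataset" that is really a dictionary. len(dataset) is 1 (the dictionary has one key), so the only index generated is 0, and the DataLoader ends up evaluating dataset[0] on a dictionary whose only key is 'data'.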
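The lookup that fails can be reproduced without PyTorch. The following is a minimal sketch, assuming the string "fake ImageFolder" as a hypothetical placeholder for the real ImageFolder object; the split arithmetic mirrors the code in the question.

```python
import math

# Stand-in for the question's first line: a dict with ONE key, 'data'.
# The string value is a hypothetical placeholder for the ImageFolder object.
dataset = {x: "fake ImageFolder" for x in ["data"]}

validation_split = 0.3
dataset_size = len(dataset)  # counts dict keys, not images
indices = list(range(dataset_size))
split = int(math.floor(validation_split * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]

# The DataLoader's fetcher evaluates dataset[idx] for each sampled index.
try:
    dataset[train_indices[0]]
    error = None
except KeyError as exc:
    error = exc  # the dict has no key 0, only 'data'
```

Here the training indices contain a single entry and the validation indices are empty, so even without the KeyError the split would not have covered the images.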
By the way, PyTorch provides torch.utils.data.random_split, so you don't have to do the train/test split yourself. You may want to look into it.
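As a rough sketch of how random_split could replace the manual index shuffling: the example below uses a small TensorDataset of zero-filled tensors as a hypothetical stand-in for the ImageFolder dataset, and the batch size, image shape, and seed are illustrative assumptions rather than values from the question. The generator argument assumes a reasonably recent PyTorch release.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for datasets.ImageFolder(...):
# 10 fake 3x8x8 "images" with binary labels.
images = torch.zeros(10, 3, 8, 8)
labels = torch.arange(10) % 2
dataset = TensorDataset(images, labels)

validation_split = 0.3
val_size = int(validation_split * len(dataset))
train_size = len(dataset) - val_size

# A seeded generator makes the split reproducible across runs.
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(
    dataset, [train_size, val_size], generator=generator
)

# Same dict layout that the question's train_model function expects.
dataloaders_dict = {
    "train": DataLoader(train_set, batch_size=4, shuffle=True),
    "val": DataLoader(val_set, batch_size=4, shuffle=False),
}
```

One thing to keep in mind: both subsets wrap the same underlying dataset, so with ImageFolder they would share a single transform. If training and validation need different transforms, a different arrangement would be required.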