首页 > 解决方案 > 在torch中使用random_split后如何获取train_dataset的路径名

问题描述

我有以下代码:

import torch, torchvision
root_dataset ="./data"
dataset = torchvision.datasets.folder.ImageFolder(root=root_dataset, transform=None, target_transform=None)
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(
    dataset=dataset,
    lengths=[num_train, num_valid]
)

我的问题是:

train_dataset使用random_splitin后如何获取路径的名称列表torch

谢谢你。

标签: pythonpytorch

解决方案


路径(和标签)存储在dataset.imgs. 例如,对于 imagenet:

In [ ]: print(dataset.imgs[0])
Out [ ]: ('/shareDB/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG', 0) 

拆分数据集后,每个拆分都指向原始数据集:

In [ ]: len(train_dataset.dataset), len(valid_dataset.dataset)
Out [ ]: (50000, 50000)

但是,每个拆分还包含为拆分选择的原始数据集的样本索引。您可以使用这些索引和原始数据集来获取为每个拆分选择的图像列表:

valid_imgs = [valid_dataset.dataset.imgs[i_] for i_ in valid_dataset.indices]
train_imgs = [train_dataset.dataset.imgs[i_] for i_ in train_dataset.indices]

推荐阅读