python - 在torch中使用random_split后如何获取train_dataset的路径名
问题描述
我有以下代码:
import torch, torchvision
root_dataset ="./data"
dataset = torchvision.datasets.folder.ImageFolder(root=root_dataset, transform=None, target_transform=None)
train_dataset, valid_dataset = torch.utils.data.dataset.random_split(
dataset=dataset,
lengths=[num_train, num_valid]
)
我的问题是:
train_dataset
使用random_split
in后如何获取路径的名称列表torch
?
谢谢你。
解决方案
路径(和标签)存储在dataset.imgs
. 例如,对于 imagenet:
In [ ]: print(dataset.imgs[0])
Out [ ]: ('/shareDB/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG', 0)
拆分数据集后,每个拆分都指向原始数据集:
In [ ]: len(train_dataset.dataset), len(valid_dataset.dataset)
Out [ ]: (50000, 50000)
但是,每个拆分还包含为拆分选择的原始数据集的样本索引。您可以使用这些索引和原始数据集来获取为每个拆分选择的图像列表:
valid_imgs = [valid_dataset.dataset.imgs[i_] for i_ in valid_dataset.indices]
train_imgs = [train_dataset.dataset.imgs[i_] for i_ in train_dataset.indices]
推荐阅读
- mysql - SQLSTATE [23000]:完整性约束违规:1062 - laraclassified - 评论插件
- javascript - 我怎样才能让这个 Javascript 计算器工作?
- python - Python Pandas:在按其他列分组时创建累积平均值
- javascript - 如何根据下拉值过滤 jquery 数据表
- node.js - Django vs Node(Express) vs Flask 用于具有高安全性和实时性的 RESTful API
- html - 跨越 div 和 span——HTML5、css 和 WCAG
- python - 如何让 Sanic 响应 http 和 ws?
- python - 我应该使用 VGG19 的哪一层来提取特征
- github-pages - 您可以在 github pages 帐户上使用 Lottie-Files / bodymovin
- three.js - 如何在three.js中将json/js模型转换为gltf模型