python - Pytorch:在 dataloader.dataset 上使用 torch.utils.random_split() 后,数据中缺少批量大小
问题描述
我使用 random_split() 将我的数据划分为训练和测试,我观察到如果在创建数据加载器后进行随机拆分,则在从数据加载器获取一批数据时会丢失批量大小。
import torch
from torchvision import transforms, datasets
from torch.utils.data import random_split
# Normalize the data
transform_image = transforms.Compose([
transforms.Resize((240, 320)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
data = '/data/imgs/train'
def load_dataset():
data_path = data
main_dataset = datasets.ImageFolder(
root = data_path,
transform = transform_image
)
loader = torch.utils.data.DataLoader(
dataset = main_dataset,
batch_size= 64,
num_workers = 0,
shuffle= True
)
# Dataset has 22424 data points
trainloader, testloader = random_split(loader.dataset, [21000, 1424])
return trainloader, testloader
trainloader, testloader = load_dataset()
现在从训练和测试加载器中获取一批图像:
images, labels = next(iter(trainloader))
images.shape
# %%
len(trainloader)
# %%
images_test, labels_test = next(iter(testloader))
images_test.shape
# %%
len(testloader)
我得到的输出没有训练或测试批次的批量大小。输出调光应该是 [batch x channel x H x W] 但我得到 [channel x H x W]。
输出:
但是,如果我从数据集创建拆分,然后使用拆分创建两个数据加载器,我会在输出中获得批量大小。
def load_dataset():
data_path = data
main_dataset = datasets.ImageFolder(
root = data_path,
transform = transform_image
)
# Dataset has 22424 data points
train_data, test_data = random_split(main_dataset, [21000, 1424])
trainloader = torch.utils.data.DataLoader(
dataset = train_data,
batch_size= 64,
num_workers = 0,
shuffle= True
)
testloader = torch.utils.data.DataLoader(
dataset = test_data,
batch_size= 64,
num_workers= 0,
shuffle= True
)
return trainloader, testloader
trainloader, testloader = load_dataset()
在运行相同的 4 个命令以获得单个训练和测试批次时:
第一种方法是错误的吗?虽然长度显示数据已经被分割。那么为什么我看不到批量大小呢?
解决方案
第一种方法是错误的。
只有DataLoader
实例返回成批的项目。类似的Dataset
情况没有。
当你打电话给make_split
你时loader.dataset
,它只是对main_dataset
(不是 a DataLoader
)的引用。结果是trainloader
and testloader
are Dataset
s not DataLoader
s。事实上loader
,DataLoader
当您从load_dataset
.
第二个版本是您应该做的以获得两个单独DataLoader
的 s。
推荐阅读
- c - 如何用另一个子字符串替换字符串的一部分
- r - How to match linked steps indicated by two columns of a data.table
- c# - C# 自定义 Json.NET 列表序列化
- php - 有没有办法在使用 php 检索后使用存储在 mysql 中的 PHP 代码?
- r - 在R中对列进行数字排序
- c - 为什么函数会被跳过而不被读取?
- javascript - 如何获取 DatePicker Cypress 的值
- javascript - 运算符 < 不能应用于 Number 和 boolean 类型
- vba - 子窗体阻塞 Recordset.AddNew(错误 3027)
- typescript - 我可以将类型用作值(或从构造函数参数正确推断泛型类类型)吗?