python - Pytorch Dataloader 没有将数据拆分成批处理
问题描述
我有这样的数据集类:
class LoadDataset(Dataset):
def __init__(self, data, label):
self.data = data
self.label = label
def __len__(self):
dlen = len(self.data)
return dlen
def __getitem__(self, index):
return self.data, self.label
然后我加载具有 [485, 1, 32, 32] 形状的图像数据集
train_dataset = LoadDataset(xtrain, ytrain)
print(len(train_dataset))
# output 485
然后我加载数据DataLoader
train_loader = DataLoader(train_dataset, batch_size=32)
然后我迭代数据:
for epoch in range(num_epoch):
for inputs, labels in train_loader:
print(inputs.shape)
输出打印torch.Size([32, 485, 1, 32, 32])
,它应该是torch.Size([32, 1, 32, 32])
,
谁能帮我?
解决方案
该__getitem__
方法应该返回 1 条数据,你返回了所有的。
试试这个:
class LoadDataset(Dataset):
def __init__(self, data, label):
self.data = data
self.label = label
def __len__(self):
dlen = len(self.data)
llen = len(self.label) # different here
return min(dlen, llen) # different here
def __getitem__(self, index):
return self.data[index], self.label[index] # different here
推荐阅读
- arrays - 字节数组随机分布到一个字节数组中
- javascript - 原生定制 Android 应用所需的指导
- c# - 将变量标头转换为 JSON JArray
- angularjs - 将下拉列表的值传递给 Angular js
- java - 使用 swagger 从枚举值中记录字符串
- javascript - moment.js 比较未按预期工作
- java - 事务提交后如何更改隔离级别
- tensorflow - BatchNormalization 炸毁了 keras 模型
- java - gui.ava.html.Html2Image 不呈现文本溢出
- powershell - 今天日期和密码到期日期少于 14 天?