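python-3.x - Problem creating a Dataset and DataLoader for HDF5 files in PyTorch: not enough values to unpack (expected 2, got 1)
Problem description
After working with Torchvision's built-in datasets, I tried to load an HDF5 file in PyTorch, without success. I read that I should define my own Dataset and DataLoader classes, with `__getitem__` to enable indexing and `__len__` to return the length of the dataset. I also read that I should define transforms, because PyTorch's default transforms expect PIL images. I tried to do this, but I get the error "ValueError: not enough values to unpack (expected 2, got 1)". What am I doing wrong?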
```python
#PyTorch packages
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
from torch import optim
from torch.autograd import Variable
from torch.utils import data
import h5py
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(0)

#open training file
with h5py.File('train_catvnoncat.h5', 'r') as hdf:
    ls = list(hdf.keys())
    print('List of datasets in this file: \n', ls)
    data = hdf.get('dataset')
    dataset1 = np.array(data)
    print('Shape of dataset1: \n', dataset1.shape)

length = len(h5py.File('train_catvnoncat.h5', 'r'))
print(length)

#image size (64,64,3) 64*64*3=12,288.
#209 training examples
#50 test examples

# Example of a picture

#Def the dataloader for h5 files:
class HDF5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_path = '/Users/teff/Downloads/'
        self.train = train_catvnoncat.h5(h5_path, 'r')
        self.train = test_catvnoncat.h5(h5_path, 'r')
        self.length = len(h5py.File(h5_path, 'r'))
        # self.transform = transform_hdf5 #I need to define the "transformToTensor"

    def __getitem__(self, index): #to enable indexing
        record = self.train[str(index)]
        return (
            record['X'].value,
            record['y'].value,
        )

    def __len__(self): #returns the length of the dataset
        return self.length

train_loader = torch.utils.data.DataLoader('train_catvnoncat.h5', shuffle=True)
test_loader = torch.utils.data.DataLoader('test_catvnoncat.h5', shuffle=True)
```
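The error in the title is most likely raised by the training loop (not shown above): `DataLoader` is given a file-name string instead of a `Dataset` instance. A `str` supports `len()` and indexing, so `DataLoader` accepts it and iterates over the characters of the file name. A minimal reproduction, assuming a loop of the usual `images, labels = batch` form:

```python
from torch.utils.data import DataLoader

# DataLoader treats whatever it receives as the dataset. With a str,
# each sample is one character of the file name; the file is never opened.
loader = DataLoader("train_catvnoncat.h5", shuffle=False)

batch = next(iter(loader))
print(batch)  # ['t']  (a batch holding one character)

# Unpacking that one-element batch into two names fails with
# "ValueError: not enough values to unpack (expected 2, got 1)".
try:
    images, labels = batch
except ValueError as err:
    print(err)
```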
Solution
Your dataset should look like this:
```python
import h5py
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


class HDF5Dataset(Dataset):
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    def __init__(self, h5_path):
        # h5_path is the full path to the file, e.g. under /Users/teff/Downloads/
        self.h5_path = h5_path
        self.train = h5py.File(h5_path, 'r')
        # Assumes one group per sample named '0', '1', ..., each holding 'X' and 'y',
        # so the number of top-level groups equals the number of samples
        self.length = len(self.train)

    def __getitem__(self, index):  # to enable indexing
        record = self.train[str(index)]
        image = record['X'][()]  # [()] reads the whole dataset; .value was removed in h5py 3.0
        # transform to PIL image (an H x W x 3 uint8 array becomes an RGB image)
        image = Image.fromarray(image.astype('uint8'))  # assume your data is uint8 RGB
        label = record['y'][()]
        # transformation here
        # torchvision PIL transformations accept one image as input
        image = self.transform(image)
        return (
            image,
            label,
        )

    def __len__(self):
        return self.length
```
PS: Have a look at the excellent PyTorch tutorial on data loading.