Problem in a custom PyTorch dataset class for .h5 files

Problem description

import glob
import os

import h5py
import numpy as np
import torch
from PIL import Image
from torchvision import transforms


class HDF5Dataset(torch.utils.data.Dataset):

    def __setup_files(self):
        # recursive=True lets "**" match dir_path itself and every subdirectory
        files = glob.glob(os.path.join(self.dir_path, '**/*.h5'), recursive=True)
        return files

    def __init__(self, dir_path, IMG_SIZE):
        self.dir_path = dir_path
        self.IMG_SIZE = IMG_SIZE
        self.files = self.__setup_files()
        self.length = len(self.files)
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __getitem__(self, idx):
        record = self.files[idx]
        # read both datasets; the with-block closes the file afterwards
        with h5py.File(record, 'r') as h5:
            image = h5['data'][()]  # [()] reads the whole dataset (.value was removed in h5py 3.0)
            label = h5['label'][()]
        image = np.asarray(image)
        image = Image.fromarray(image.astype('uint8'), 'RGB')

        return self.transform(image), label

    def __len__(self):
        return self.length

This is my custom dataset class. I am trying to recursively load every .h5 file in a directory.
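A note on the recursive part: glob.glob only treats ** as "zero or more directories" when it is called with recursive=True. Without that flag, ** behaves like *, so a pattern such as os.path.join(dir_path, '**/*.h5') matches only files exactly one subdirectory below dir_path. A minimal sketch (find_h5_files is a hypothetical helper name):

```python
import glob
import os


def find_h5_files(dir_path):
    """Return every .h5 file under dir_path, at any depth, in sorted order."""
    # With recursive=True, "**" matches dir_path itself and all nested
    # subdirectories; without it, "**" matches exactly one directory level.
    pattern = os.path.join(dir_path, "**", "*.h5")
    return sorted(glob.glob(pattern, recursive=True))
```

Sorting is optional but gives a stable index-to-file mapping across runs, which makes a given idx in __getitem__ reproducible.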

I think the problem is in __getitem__, but I am not sure what it is.
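One way to see what __getitem__ actually returns is to index the dataset directly (for example datasets['train'][0]), which runs in the main process without worker wrapping, and inspect the type and dtype of each field. The sketch below uses a hypothetical helper, find_uncollatable_fields, that flags numpy arrays whose dtype kind is bytes ('S'), unicode ('U') or object ('O'); default_collate refuses to batch arrays of those dtypes.

```python
import numpy as np


def find_uncollatable_fields(sample):
    """Return (index, dtype) for every field of a sample tuple that is a numpy
    array with a bytes ('S'), unicode ('U') or object ('O') dtype.

    torch's default_collate raises "batch must contain tensors, numpy arrays,
    numbers, dicts or lists; found ..." for arrays of these dtypes.
    """
    problems = []
    for index, field in enumerate(sample):
        if isinstance(field, np.ndarray) and field.dtype.kind in "SUO":
            problems.append((index, str(field.dtype)))
    return problems
```

If find_uncollatable_fields(datasets['train'][0]) reported index 1, the label would be the field to convert.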

When I create the DataLoader like this:

dataloaders['train'] = torch.utils.data.DataLoader(datasets['train'],
                                              batch_size=batch_size, shuffle=True, pin_memory=True, 
                                              num_workers=12)

and then run this line:

inputs, classes = next(iter(dataloaders['train']))

it raises a TypeError:

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/collate.py", line 62, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
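The message means one field of the returned (image, label) tuple is a numpy array with object dtype. The image cannot be that field, because self.transform ends with ToTensor and Normalize and therefore returns a tensor, so the label read from h5['label'] is the likely culprit (h5py, for example, reads variable-length string datasets as numpy arrays of object dtype). A minimal sketch of converting such a label to a plain int before returning it, assuming each file holds exactly one label; label_to_index and the class_to_idx mapping are hypothetical names, and the mapping is only needed if labels are stored as class names rather than numbers:

```python
import numpy as np


def label_to_index(raw_label, class_to_idx=None):
    """Convert a label read from HDF5 into a plain int class index.

    raw_label may be a Python/numpy scalar, bytes/str, or a one-element numpy
    array (possibly with object dtype). class_to_idx is a hypothetical
    name -> index mapping, used only when the label is text.
    """
    value = raw_label
    if isinstance(value, np.ndarray):
        if value.size != 1:
            raise ValueError("expected a single label, got shape %r" % (value.shape,))
        value = value.reshape(-1)[0]  # unwrap the single element
    if isinstance(value, bytes):      # also covers numpy.bytes_
        value = value.decode("utf-8")
    if isinstance(value, str):
        if class_to_idx is None:
            return int(value)         # text holding a number, e.g. "3"
        return class_to_idx[value]
    return int(value)                 # numeric types, including numpy integers
```

The last line of __getitem__ would then become return self.transform(image), label_to_index(label), and default_collate batches a list of Python ints into an int64 tensor.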

I would like to know exactly how to build a custom dataset for .h5 files in PyTorch and how to load it.
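A minimal end-to-end sketch, under these assumptions: each .h5 file stores one H x W x C uint8 image under "data" and one scalar integer class index under "label" (the key names are from the question; the layout and the class name H5ImageDataset are assumptions). It opens each file only inside __getitem__, so no h5py handle is shared between DataLoader worker processes, reads with [()] instead of .value (removed in h5py 3.0), and returns the label as a plain int.

```python
import glob
import os

import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class H5ImageDataset(Dataset):
    """One sample per .h5 file (hypothetical layout: HWC uint8 image under
    "data", scalar integer class index under "label")."""

    def __init__(self, dir_path, transform=None):
        # recursive=True makes "**" descend into every subdirectory
        pattern = os.path.join(dir_path, "**", "*.h5")
        self.files = sorted(glob.glob(pattern, recursive=True))
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Open and close the file per sample, so each worker uses its own handle.
        with h5py.File(self.files[idx], "r") as h5:
            image = h5["data"][()]        # numpy array, read fully into memory
            label = int(h5["label"][()])  # plain int, collated into an int64 tensor
        if self.transform is not None:
            return self.transform(image), label
        # Default conversion: HWC uint8 -> CHW float32 in [0, 1]
        image = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float() / 255.0
        return image, label
```

To reuse the torchvision pipeline from the question, pass transform=transforms.Compose([transforms.ToPILImage(), ...]) with the original steps after ToPILImage, which accepts an H x W x C numpy array. DataLoader(H5ImageDataset(path), batch_size=batch_size, shuffle=True, num_workers=12) should then yield (images, labels) batches.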

Thanks!

Tags: python, pytorch, h5py, dataloader

Solution

