python - .h5 文件的自定义 pytorch 数据集类中有问题
问题描述
class HDF5Dataset(torch.utils.data.Dataset):
def __setup_files(self):
files = glob.glob(os.path.join(self.dir_path,'**/*.h5'))
return files
def __init__(self, dir_path, IMG_SIZE):
self.dir_path = dir_path
self.IMG_SIZE = IMG_SIZE
self.files = self.__setup_files()
self.length = len(self.files)
self.transform = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __getitem__(self, idx):
record = self.files[idx]
h5 = h5py.File(record , 'r')
image = h5['data'].value
label = h5['label'].value
h5.close()
image = np.asarray(image)
image = Image.fromarray(image.astype('uint8'), 'RGB')
return self.transform(image), label
def __len__(self):
return self.length
这是我的自定义数据集类,我正在尝试以递归方式加载目录中的每个 h5 文件。
我认为 def getitem 有问题,但我不确定它是什么。
当我尝试加载这个
dataloaders['train'] = torch.utils.data.DataLoader(datasets['train'],
batch_size=batch_size, shuffle=True, pin_memory=True,
num_workers=12)
而这段代码,
inputs, classes = next(iter(dataloaders['train']))
它给出类型错误:
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
return [default_collate(samples) for samples in transposed]
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/_utils/collate.py", line 62, in default_collate
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object
我想确切地知道如何在 pytorch 中为 h5 文件构建自定义数据集以及如何加载它们。
谢谢!
解决方案
推荐阅读
- react-native - 文件未从模拟器发送
- android - 在 Android sdk 30 中存储我的应用程序创建的大文件的最佳方式是什么?
- css - fiori基础知识中警报组件中的覆盖图标
- python - 当我的文件名为“web_automation\pra.txt”时,除了更改文件名外,我还能做什么?
- flutter - 错误:使用 EasyLocalization Flutter 更改翻译文件的路径
- php - 我得到这个页面不起作用,并且它在 PHP 中重定向了太多次
- kubernetes - 使用 kfp.dls.containerOp() 在 Kubeflow Pipelines 上运行多个脚本
- python - 如何计算这种格式的数据帧的方差?
- c++ - 如何引用“timeElapsed”中存储的内容以在另一个函数中进行比较?
- javascript - Webpack 5 React 组件库 UMD 与 SourceMaps 捆绑