Optimizing a PyTorch DataLoader for reading small patches from full-HD images

Problem description

I am training my neural network with the PyTorch framework. The data are full-HD images (1920x1080), but in every iteration I only need a random 256x256 patch cropped from each image. My network is relatively small (5 convolutional layers), so the bottleneck is caused by loading the data. My current code is given below. Is there any way to optimize the data loading and speed up the training?

Code

from pathlib import Path

import numpy
import skimage.io
import torch.utils.data as data

import Imath
import OpenEXR


class Ours(data.Dataset):
    """
    Loads patches of resolution 256x256. Patches are selected such that they contain at least 1 unknown pixel.
    """

    def __init__(self, data_dirpath, split_name, patch_size):
        super(Ours, self).__init__()
        self.dataroot = Path(data_dirpath) / split_name
        self.video_names = []
        for video_path in sorted(self.dataroot.iterdir()):
            for i in range(4):
                for j in range(11):
                    view_num = i * 12 + j
                    self.video_names.append((video_path.stem, view_num))
        self.patch_size = patch_size
        return

    def __getitem__(self, index):
        video_name, view_num = self.video_names[index]

        # Pick a random top-left corner so that the patch stays inside the 1920x1080 frame
        patch_start_pt = (numpy.random.randint(1080 - self.patch_size + 1),
                          numpy.random.randint(1920 - self.patch_size + 1))

        frame1_path = self.dataroot / video_name / f'render/rgb/{view_num + 1:04}.png'
        frame2_path = self.dataroot / video_name / f'render/rgb/{view_num + 2:04}.png'
        depth_path = self.dataroot / video_name / f'render/depth/{view_num + 1:04}.exr'
        mask_path = self.dataroot / video_name / f'render/masks/{view_num + 1:04}.png'
        frame1 = self.get_image(frame1_path, patch_start_pt)
        frame2 = self.get_image(frame2_path, patch_start_pt)
        mask = self.get_mask(mask_path, patch_start_pt)
        depth = self.get_depth(depth_path, patch_start_pt, mask)

        data_dict = {
            'frame1': frame1,
            'frame2': frame2,
            'mask': mask,
            'depth': depth,
        }
        return data_dict

    def __len__(self):
        return len(self.video_names)

    def get_mask(self, path: Path, patch_start_point: tuple):
        h, w = patch_start_point
        # Crop the mask patch and add a leading channel dimension
        mask = skimage.io.imread(path.as_posix())[h:h + self.patch_size, w:w + self.patch_size][None]
        return mask

    def get_image(self, path: Path, patch_start_point: tuple):
        h, w = patch_start_point
        image = skimage.io.imread(path.as_posix())
        image = image[h:h + self.patch_size, w:w + self.patch_size, :3]
        # Scale to [-1, 1] and convert from HWC to CHW layout
        image = image.astype(numpy.float32) / 255 * 2 - 1
        image_cf = numpy.moveaxis(image, [0, 1, 2], [1, 2, 0])
        return image_cf

    def get_depth(self, path: Path, patch_start_point: tuple, mask: numpy.ndarray):
        h, w = patch_start_point

        # Read the full-resolution depth map from the 'B' channel of the EXR file
        exrfile = OpenEXR.InputFile(path.as_posix())
        raw_bytes = exrfile.channel('B', Imath.PixelType(Imath.PixelType.FLOAT))
        depth_vector = numpy.frombuffer(raw_bytes, dtype=numpy.float32)
        height = exrfile.header()['displayWindow'].max.y + 1 - exrfile.header()['displayWindow'].min.y
        width = exrfile.header()['displayWindow'].max.x + 1 - exrfile.header()['displayWindow'].min.x
        depth = numpy.reshape(depth_vector, (height, width))

        # Crop the patch, add a channel dimension, and apply the mask
        depth = depth[h:h + self.patch_size, w:w + self.patch_size]
        depth = depth[None]
        depth = depth.astype(numpy.float32)
        depth = depth * mask
        return depth

Finally, I create the DataLoader as follows:

train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)

What I have tried so far

  1. I searched for a way to read only a part of an image. Unfortunately, I could not find any leads; the Python imaging libraries appear to read and decode the complete image.
  2. I plan to read several patches from a single image so that I need to read fewer images. However, in the PyTorch framework the __getitem__() function must return a single sample, not a batch, so each __getitem__() call can only read one patch.
  3. I plan to work around this by reading 4 patches in __getitem__() and returning patches of shape (4, 3, 256, 256) instead of (3, 256, 256). Later, when I read a batch with the DataLoader, I would get a batch of shape (BS, 4, 3, 256, 256) instead of (BS, 3, 256, 256). I can then merge the first two dimensions to convert (BS, 4, 3, 256, 256) into (BS*4, 3, 256, 256). Consequently, I can reduce the batch_size (BS) by a factor of 4 and hope this speeds up data loading by roughly 4x (a sketch of this idea follows the list).
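
As a minimal sketch of idea 3 (not part of the original post): a hypothetical MultiPatchDataset decodes each full-HD frame once, cuts several random crops from the in-memory array, and stacks them; the training loop then flattens the extra patch dimension back into the batch dimension. The class name, image_paths, patches_per_image, and batch_size are illustrative placeholders, and a real dataset would reuse the frame/mask/depth loading code from the question.

import numpy
import skimage.io
from torch.utils.data import DataLoader, Dataset


class MultiPatchDataset(Dataset):
    """Hypothetical dataset returning several random patches per decoded image."""

    def __init__(self, image_paths, patch_size=256, patches_per_image=4):
        self.image_paths = image_paths
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # Decode the full-HD image only once per __getitem__ call
        image = skimage.io.imread(self.image_paths[index])          # (1080, 1920, 3)
        patches = []
        for _ in range(self.patches_per_image):
            # Cut a random in-bounds crop from the already decoded array
            h = numpy.random.randint(image.shape[0] - self.patch_size + 1)
            w = numpy.random.randint(image.shape[1] - self.patch_size + 1)
            patch = image[h:h + self.patch_size, w:w + self.patch_size, :3]
            patch = patch.astype(numpy.float32) / 255 * 2 - 1       # scale to [-1, 1]
            patches.append(numpy.moveaxis(patch, 2, 0))             # HWC -> CHW
        return numpy.stack(patches)                                 # (4, 3, 256, 256)


# batch_size is divided by 4 because every sample already carries 4 patches
loader = DataLoader(MultiPatchDataset(image_paths), batch_size=batch_size // 4,
                    shuffle=True, pin_memory=True, num_workers=4)

for batch in loader:                                  # (BS/4, 4, 3, 256, 256)
    b, p, c, ph, pw = batch.shape
    batch = batch.reshape(b * p, c, ph, pw)           # -> (BS, 3, 256, 256)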

Are there any other options? I am open to all kinds of suggestions. Thank you!

Tags: python, performance, pytorch, dataloader, pytorch-dataloader
