python - 优化 pytorch 数据加载器以读取全高清图像中的小补丁
问题描述
我正在使用 PyTorch 框架训练我的神经网络。数据是全高清图像 (1920x1080)。但是在每次迭代中,我只需要从这些图像中裁剪出一个随机的 256x256 补丁。我的网络相对较小(5 个卷积层),因此瓶颈是由加载数据引起的。我在下面提供了我当前的代码。有什么方法可以优化加载数据并加快训练速度?
代码:
from pathlib import Path
import numpy
import skimage.io
import torch.utils.data as data
import Imath
import OpenEXR
class Ours(data.Dataset):
"""
Loads patches of resolution 256x256. Patches are selected such that they contain atleast 1 unknown pixel
"""
def __init__(self, data_dirpath, split_name, patch_size):
super(Ours, self).__init__()
self.dataroot = Path(data_dirpath) / split_name
self.video_names = []
for video_path in sorted(self.dataroot.iterdir()):
for i in range(4):
for j in range(11):
view_num = i * 12 + j
self.video_names.append((video_path.stem, view_num))
self.patch_size = patch_size
return
def __getitem__(self, index):
video_name, view_num = self.video_names[index]
patch_start_pt = (numpy.random.randint(1080), numpy.random.randint(1920))
frame1_path = self.dataroot / video_name / f'render/rgb/{view_num + 1:04}.png'
frame2_path = self.dataroot / video_name / f'render/rgb/{view_num + 2:04}.png'
depth_path = self.dataroot / video_name / f'render/depth/{view_num + 1:04}.exr'
mask_path = self.dataroot / video_name / f'render/masks/{view_num + 1:04}.png'
frame1 = self.get_image(frame1_path, patch_start_pt)
frame2 = self.get_image(frame2_path, patch_start_pt)
mask = self.get_mask(mask_path, patch_start_pt)
depth = self.get_depth(depth_path, patch_start_pt, mask)
data_dict = {
'frame1': frame1,
'frame2': frame2,
'mask': mask,
'depth': depth,
}
return data_dict
def __len__(self):
return len(self.video_names)
@staticmethod
def get_mask(path: Path, patch_start_point: tuple):
h, w = patch_start_point
mask = skimage.io.imread(path.as_posix())[h:h + self.patch_size, w:w + self.patch_size][None]
return mask
def get_image(self, path: Path, patch_start_point: tuple):
h, w = patch_start_point
image = skimage.io.imread(path.as_posix())
image = image[h:h + self.patch_size, w:w + self.patch_size, :3]
image = image.astype(numpy.float32) / 255 * 2 - 1
image_cf = numpy.moveaxis(image, [0, 1, 2], [1, 2, 0])
return image_cf
def get_depth(self, path: Path, patch_start_point: tuple, mask: numpy.ndarray):
h, w = patch_start_point
exrfile = OpenEXR.InputFile(path.as_posix())
raw_bytes = exrfile.channel('B', Imath.PixelType(Imath.PixelType.FLOAT))
depth_vector = numpy.frombuffer(raw_bytes, dtype=numpy.float32)
height = exrfile.header()['displayWindow'].max.y + 1 - exrfile.header()['displayWindow'].min.y
width = exrfile.header()['displayWindow'].max.x + 1 - exrfile.header()['displayWindow'].min.x
depth = numpy.reshape(depth_vector, (height, width))
depth = depth[h:h + self.patch_size, w:w + self.patch_size]
depth = depth[None]
depth = depth.astype(numpy.float32)
depth = depth * mask
return depth
最后,我正在创建一个 DataLoader,如下所示:
train_data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
到目前为止我已经尝试过:
- 我已经搜索是否可以读取图像的一部分。不幸的是,我没有得到任何线索。看起来 python 库读取了完整的图像。
- 我计划从单个图像中读取更多补丁,以便我需要读取更少的图像。但在 PyTorch 框架中,
get_item()
函数必须返回单个样本,而不是批次。所以,在每个get_item()
我只能读取一个补丁。 - 我计划通过以下方式规避此问题:读取 4 个补丁
get_item()
并返回形状补丁(4,3,256,256)
而不是(3,256,256)
. 稍后当我使用数据加载器读取批次时,我会得到一批形状(BS,4,3,256,256)
而不是(BS,3,256,256)
. 然后我可以连接数据dim=1
以转换(BS,4,3,256,256)
为(BS*4,3,256,256)
. 因此,我可以将batch_size
(BS
) 减少 4,并希望这会将数据加载速度提高 4 倍。
还有其他选择吗?我愿意接受各种建议。谢谢!
解决方案
推荐阅读
- python - 设计师重叠问题
- flutter - 'https://www.googleapis.com/auth/contacts.readonly' 不工作显示加载...在谷歌标志
- javascript - 试图通过 puppeteer 在终端上打印控制台开发工具,我得到一个不完整的输出
- python - 基于 2 个数据帧的 pandas 的高效数据操作
- stored-procedures - 获取程序的主体
- python - 如何在Networkx中从叶子到根的所有路径中的每个叶子连接到每个节点
- javascript - 使用 ckeditor 5 自定义构建和 vuejs 2
- indexing - 索引匹配的多个条件
- minio - 当一个minio节点的磁盘满了怎么办?
- php - 如何使用 Artifactory 作为我的 PHP 作曲家依赖项的缓存代理?