首页 > 解决方案 > Pytorch Data Loader 将图像连接到输入图像

问题描述

在 PyTorch 数据加载器中,我如何将图像(比如 x.jpg)以带内方式连接到每个输入图像。即,实际上我将有 4 波段输入(3 波段输入 jpg 和 1 波段 x.jpg。如何实现它。

请在下面找到我当前数据加载器的示例,以加载图像。为此,我想将 x.jpg 添加到“图像”(即输入图像,而不是掩码)

from PIL import Image

class lakeDataSet(Dataset):
   def __init__(self, root, transform):
   super().__init__()
   self.root = root
   self.img_dir = os.path.join(root,'image-c3/c3-crop')   #9UAV
   self.mask_dir = os.path.join(root,'label-c3/c3-crop')
   # self.mask_dir = os.path.join(root,'test')
   self.files = [fname for fname in os.listdir(self.img_dir) if fname.endswith('.jpg')]
self.transform = transform

   def __len__(self):
     return len(self.files)

   def __getitem__(self,I):
     fname = self.files[i]
     img_path = os.path.join(self.img_dir, fname)
     mask_path = os.path.join(self.mask_dir, fname)

     img = self.transform(Image.open(img_path))
     mask = self.transform(Image.open(mask_path))
     return img, mask

标签: pythoncomputer-visionpytorch

解决方案


我想self.transform已经有了ToTensor。否则,您也应该指定它。

然后你可以连接第一个维度。像

x_jpg = self.transform(Image.open('x.jpg'))
img = torch.cat((img, x_jpg), 0)

必须只有 1 个通道,如果它是 RGB ,x.jpg那么显然它将变成 6 个通道而不是 4 个。


推荐阅读