首页 > 解决方案 > PyTorch / SpykeTorch 中的 DataLouder:转换数据提取的问题

问题描述

同事们,我在 PyTorch 和 SpykeTorch(基于 PyTorch)中使用神经网络,我需要创建图像数据集并将它们放在 DataLouders 中以进行进一步处理。完整过程如下:
1.生成张量,
2.使用torchvision.transforms.ToPILImage()对其进行变换,
3.将创建的图像保存到一个目录,
4.基于具有图像变换的目录创建一个ImageFolder(使用过滤器) ,
5. 从 ImageFolder 创建 DataLoader。

image_set = torch.rand([10000, 28, 28], dtype=torch.float)   

path = './data/images/'  
os.makedirs(path)  
        
tTPI = torchvision.transforms.ToPILImage()   
    
for i in range(n):   
    single_image = tTPI(image_set[i])     
    image_file = path+f'pic_{i}.jpg'   
    saved_image = single_image.save(f'{path}pic_{i}.jpg')    

kernels = [ SpykeTorch.utils.DoGKernel(7,1,2),
            SpykeTorch.utils.DoGKernel(7,2,1)]
filter = SpykeTorch.utils.Filter(kernels, padding = 3, thresholds = 50)
s1 = S1Transform(filter)

RandomImageFolder = ImageFolder(root='./data/', transform = s1)  
RandomDataLoader = DataLoader(RandomImageFolder, batch_size=len(RandomImageFolder))   

此外,来自 DataLoader 的数据用于工作(例如,它被神经网络识别)。

for data, target in RandomDataLoader:
    prediction_X, prediction_y = predict(model, data, target)

问题是,从DataLoader拉取数据和标签时,出现错误:

RuntimeError: Given groups = 1, weight of size [2, 1, 7, 7], expected input [1, 3, 28, 28] to have 1 channels, but got 3 channels instead

从维度 [1, 2, 7, 7] 来看,错误发生在第 4 阶段,其中使用一组过滤器进行转换。
但是,在这种情况下使用不同的过滤器集不会导致任何错误。
如何在不更换过滤器的情况下解决问题?

标签: pytorchpytorch-dataloader

解决方案


问题是生成的 * .jpg 文件在加载到 ImageFolder 时被视为 RGB,并且大小为 [1, 3, 28, 28] 而不是 [1, 1, 28, 28]。

我在变换中添加:

from PIL import ImageOps
 
gray_image = ImageOps.grayscale(image)

推荐阅读