Loading images in PyTorch

Problem description

I'm new to PyTorch and am working on a GAN model. I want to load my image dataset. This is how I do it with Keras:

from os import listdir

from numpy import asarray
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img

def load_images(path, size=(128,128)):
    data_list = list()
    # enumerate filenames in directory, assume all are images
    for filename in listdir(path):
        # load and resize the image
        pixels = load_img(path + filename, target_size=size)
        # convert to numpy array
        pixels = img_to_array(pixels)
        # store
        data_list.append(pixels)
    return asarray(data_list)

# dataset path
path = 'mypath/'
# load dataset A
dataA = load_images(path + 'A/')
# load dataset B
dataB = load_images(path + 'B/')

How can I do the same thing in PyTorch? Any help is appreciated. Thanks.

Tags: pytorch

Answer


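import os

import torch
from torchvision import datasets, transforms


def load_training(root_path, subdir, batch_size, **kwargs):
    # Resize to 256x256, then take a random 224x224 crop and a random horizontal
    # flip (data augmentation); ToTensor gives a C x H x W float tensor in [0, 1]
    transform = transforms.Compose(
        [transforms.Resize([256, 256]),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    # ImageFolder expects one subdirectory per class under root
    data = datasets.ImageFolder(root=os.path.join(root_path, subdir), transform=transform)
    # extra keyword arguments go straight to DataLoader, e.g. num_workers=4
    train_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)

    return train_loader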

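Note that `datasets.ImageFolder` expects its `root` folder to contain one subdirectory per class (for example `root/class_x/xxx.png`), so the folder you pass must hold the images inside subfolders, not directly. I hope this works.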

