PyTorch augmentation

Problem description

I'm new to machine learning and PyTorch. I'm using the imgaug library (https://github.com/aleju/imgaug) for image augmentation.

I have this code:

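import numpy as np
import torch
import torchvision
import imgaug as ia
from imgaug import augmenters as iaa

def random_aug_use(aug):
    # Assumption: this helper is not shown in the question; here it applies
    # `aug` to about 50% of images, like the Sometimes pattern in imgaug's README
    return iaa.Sometimes(0.5, aug)

class ImgAugTransform:
    def __init__(self):
        self.aug = iaa.Sequential(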
            [
                # Apply the following augmenters to most images
                iaa.Fliplr(0.5), # horizontally flip 50% of all images
                iaa.Flipud(0.2),  # vertically flip 20% of all images
                random_aug_use(iaa.CropAndPad(  # crop by up to 10% or pad by up to 20% of height/width
                    percent=(-0.1, 0.2),
                    pad_mode=ia.ALL,
                    pad_cval=(0.,255)
                )),
                random_aug_use(iaa.Affine(
                    scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, # scale images to 80-120% of their size, individually per axis
                    translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, # translate by -20 to +20 percent (per axis)
                    rotate=(-45, 45), # rotate by -45 to +45 degrees
                    shear=(-16, 16), # shear by -16 to +16 degrees
                    order=[0, 1], # use nearest neighbour or bilinear interpolation (fast)
                    cval=(0, 255), # if mode is constant, use a cval between 0 and 255
                    mode=ia.ALL # use any of scikit-image's warping modes (see 2nd image from the top for examples)
                ))
            ], 
            random_order=True)
    
    def __call__(self, img):
        img = np.array(img)
        return self.aug.augment_image(img)

train_transforms = ImgAugTransform()

train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transforms)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=batch_size)

But now I can't do this:

X_batch, y_batch = next(iter(train_dataloader))

I get this error:

ValueError: some of the strides of a given numpy array are negative. This is currently not supported, but will be added in future releases.

Tags: machine-learning, deep-learning, computer-vision, pytorch

Solution


I ran into this error too. The fix that worked for me was:

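def __call__(self, img):
    img = np.array(img)
    # .copy() turns a possibly negative-stride view (e.g. from a flip)
    # into a contiguous array that PyTorch can convert to a tensor
    return self.aug.augment_image(img).copy()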
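To see why `.copy()` helps, here is a NumPy-only sketch. It assumes that the flip augmenters return a reversed view such as `arr[:, ::-1]` rather than a new array; such a view has a negative stride, which PyTorch's NumPy-to-tensor conversion (`torch.from_numpy` / `torch.as_tensor`, used when the DataLoader collates samples) rejects. The 4x6x3 image below is an arbitrary stand-in for a decoded picture.

```python
import numpy as np

# Stand-in for a decoded RGB image: height 4, width 6, 3 channels, uint8
img = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)

# A horizontal flip done by slicing returns a view, not a new array
flipped = img[:, ::-1]
print(flipped.strides)  # (18, -3, 1): the width stride is negative

# .copy() allocates a fresh C-contiguous buffer with positive strides
fixed = flipped.copy()
print(fixed.strides)  # (18, 3, 1)

# Same pixel values, different memory layout
assert np.array_equal(flipped, fixed)
```

Only the memory layout changes; the pixels are identical, so the copy is safe to hand to the DataLoader.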

However, if you are composing imgaug with torchvision.transforms, you can do the following:

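def __call__(self, img):
    # Run the imgaug pipeline on a NumPy version of the PIL image
    img = self.aug.augment_image(np.array(img))
    # self.normalization is assumed to be a (mean, std) pair set in __init__
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(self.normalization[0],
                                         self.normalization[1]),
    ])
    # ToTensor also fails on negative strides, so copy before converting
    return transforms(img.copy())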
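If you would rather not remember `.copy()` in every `__call__`, one option is to wrap the augmenter once. The sketch below uses a hypothetical helper, `ContiguousOutput` (not part of imgaug or torchvision), built on `np.ascontiguousarray`, which copies only when the array is not already C-contiguous. A plain slicing function stands in for the real `iaa.Sequential` pipeline so the example runs with NumPy alone.

```python
import numpy as np

class ContiguousOutput:
    """Hypothetical wrapper: call an augmenter, then make its output C-contiguous."""

    def __init__(self, augment):
        # augment: any callable mapping an ndarray to an ndarray,
        # e.g. an ImgAugTransform instance or iaa.Sequential(...).augment_image
        self.augment = augment

    def __call__(self, img):
        out = self.augment(np.asarray(img))
        # Copies only when out is not C-contiguous (e.g. a negative-stride view)
        return np.ascontiguousarray(out)

# Stand-in augmenter: a horizontal flip via slicing (returns a negative-stride view)
transform = ContiguousOutput(lambda a: a[:, ::-1])

img = np.zeros((4, 6, 3), dtype=np.uint8)
img[:, 0] = 255  # mark the leftmost column
out = transform(img)
print(out.strides, out[0, -1])
```

With the code from the question you could pass `ContiguousOutput(ImgAugTransform())` to `ImageFolder` in place of `train_transforms`; the design choice is simply to fix the layout in one place instead of inside each transform.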
