machine-learning - PyTorch 增强
问题描述
我是机器学习和 pytorch 的新手。我正在使用 imgaug 库进行图像增强(https://github.com/aleju/imgaug)
我有这个代码:
class ImgAugTransform:
def __init__(self):
self.aug = seq = iaa.Sequential(
[
# Apply the following augmenters to most images
iaa.Fliplr(0.5), # horizontally flip 50% of all images
iaa.Flipud(0.2), # vertically flip 20% of all images
random_aug_use(iaa.CropAndPad( # crop images by -5% to 10% of their height/width
percent=(-0.1, 0.2),
pad_mode=ia.ALL,
pad_cval=(0.,255)
)),
random_aug_use(iaa.Affine(
scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, # scale images to 80-120% of their size, individually per axis
translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, # translate by -20 to +20 percent (per axis)
rotate=(-45, 45), # rotate by -45 to +45 degrees
shear=(-16, 16), # shear by -16 to +16 degrees
order=[0, 1], # use nearest neighbour or bilinear interpolation (fast)
cval=(0, 255), # if mode is constant, use a cval between 0 and 255
mode=ia.ALL # use any of scikit-image's warping modes (see 2nd image from the top for examples)
))
],
random_order=True)
def __call__(self, img):
img = np.array(img)
return self.aug.augment_image(img)
train_transforms = ImgAugTransform()
train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transforms)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=batch_size)
所以现在我不能这样做:
X_batch, y_batch = next(iter(train_dataloader))
我得到错误:
ValueError: some of the strides of a given numpy array are negative. This is currently not supported, but will be added in future releases.
解决方案
我也遇到了这个错误。对我有用的解决方案是:
def __call__(self, img):
img = np.array(img)
return self.aug.augment_image(img).copy()
但是,如果您正在与您一起创作imgaug
,则torchvision.transforms
可以执行以下操作:
def __call__(self, img):
img = self.aug.augment_image(np.array(img))
transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(self.normalization[0],
self.normalization[1]),
])
return transforms(img.copy())
推荐阅读
- firebase - Arduino wifi > firebase 和超过 1 个 LED 控制
- python - 美丽的汤没有为使用相同 URL 的两个不同程序找到相同的结果
- python - 年龄输入,制作有年龄和年龄组的小系统
- sql-server - 在 SQL 中,循环子项以在每个级别获得销售额的最佳方法是什么?
- twitter-bootstrap - 如何使用 ViewChild 在 ngAfterViewInit 方法上打开模式?
- awk - 仅将大写单词转换为小写以取消名词的大写
- ffmpeg - 如何处理 ffplay 播放 iPhone 的视频太慢?
- javascript - 我们是否需要为所有不同的第三方脚本添加 noscript 标签
- rust - 为什么我会收到错误“没有为类型 Option 找到名为 collect 的方法”?
- r - 添加列名作为特定列的前缀