首页 > 解决方案 > 多输入 ImageDataGenerator

问题描述

我已经训练了 9 个不同的 UNet 进行图像分割,每个都在同一个数据集上,但使用不同的掩码,因为每个 UNet 分割图像的不同部分。我还尝试在多类场景中训练一个 UNet,但单个 UNet 的性能优于多类场景,因此我采用了单个 UNet 的路线。现在我正在尝试通过删除单个 UNet 的最后一个全连接层并在最后添加一些输出多类分割图的公共层来将单个 UNet 加入一个网络,希望这会改善我的分割结果。但是,我被困在数据生成器部分。目前我正在使用以下生成器来创建训练所需的图像和蒙版。

    # Function for image data generator
def dataGenerator(batch_size, path, image_folder,
                  mask_folder, aug_dict, image_color_mode="grayscale", mask_color_mode="grayscale",
                  image_save_prefix="image", mask_save_prefix="mask", num_class=10, save_to_dir=None,
                  target_size=(256, 256), seed=1):

    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)

    image_generator = image_datagen.flow_from_directory(
        path,
        classes=[image_folder],
        class_mode=None,
        color_mode=image_color_mode,
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=image_save_prefix,
        seed=seed)

    mask_generator = mask_datagen.flow_from_directory(
        path,
        classes=[mask_folder],
        class_mode=None,
        color_mode=mask_color_mode,
        target_size=target_size,
        batch_size=batch_size,
        save_to_dir=save_to_dir,
        save_prefix=mask_save_prefix,
        seed=seed)

    train_generator = zip(image_generator, mask_generator)

    for (img, mask) in train_generator:
        img, mask = adjustData(img, mask, num_class)
        yield img, mask

此函数生成我的数据生成器,然后输出图像和掩码以进行数据调整(标准化)。在此之后,我使用以下方法创建生成器:

trainGene = dataGenerator(batch_size,
                      path_train,
                      'images',
                      'masks',
                      data_gen_args,
                      save_to_dir=None)

最后,使用 model.fit 训练网络。

我被困在为 9 个输入复制这个。

标签: pythonkerasunity3d-unetdata-generation

解决方案


推荐阅读