首页 > 解决方案 > Tensorflow from_generator 拟合时输入不足

问题描述

我正在处理图像并尝试基于tf.keras.preprocessing.image.ImageDataGenerator类生成新图像以增加我的数据集。我遇到了问题,我的生成器在生成期间用完了示例model.fit

这是我用来生成数据和创建数据集的函数

def tfds_imgen(ds, imgen, batch_size, batches_per):
    for images, labels in ds:
        flow_ = imgen.flow(images, labels, batch_size=batch_size)
        for _ in range(batches_per):
            yield next(flow_)

imgen = tf.keras.preprocessing.image.ImageDataGenerator(rotation_range=60)

gen = tfds_imgen(
    train_ds.as_numpy_iterator(), imgen,
    batch_size=BATCH_SIZE, batches_per=1)

train_ds_gen = tf.data.Dataset.from_generator(
    lambda: gen,
    output_types = (tf.float32, tf.float32),
    output_shapes = ([None, 64, 64, 3], [None, 120]),
)

train_ds是形状(64,64,3),我希望为train_ds. 运行此代码会返回以下错误:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least steps_per_epoch * epochs.

我知道上面的代码只生成 1 个纪元的图像。为了解决这个问题,我尝试增加使用batches_per参数生成的图像数量,使其等于时期数。这解决了这个问题,但是,当我尝试将 与 连接时train_dstrain_ds_gen我仍然遇到类似的问题,因为 . 的train_ds_gen示例比 . 多 10 倍train_ds

我希望在每个时期都获得与真实图像一样多的生成图像。

如何设置生成器以生成足够的图像,同时拥有在生成图像和真实图像之间平衡的数据集。

标签: tensorflow

解决方案


推荐阅读