tensorflow - Tensorflow from_generator 拟合时输入不足
问题描述
我正在处理图像并尝试基于tf.keras.preprocessing.image.ImageDataGenerator
类生成新图像以增加我的数据集。我遇到了问题,我的生成器在生成期间用完了示例model.fit
这是我用来生成数据和创建数据集的函数
def tfds_imgen(ds, imgen, batch_size, batches_per):
for images, labels in ds:
flow_ = imgen.flow(images, labels, batch_size=batch_size)
for _ in range(batches_per):
yield next(flow_)
imgen = tf.keras.preprocessing.image.ImageDataGenerator(rotation_range=60)
gen = tfds_imgen(
train_ds.as_numpy_iterator(), imgen,
batch_size=BATCH_SIZE, batches_per=1)
train_ds_gen = tf.data.Dataset.from_generator(
lambda: gen,
output_types = (tf.float32, tf.float32),
output_shapes = ([None, 64, 64, 3], [None, 120]),
)
这train_ds
是形状(64,64,3)
,我希望为train_ds
. 运行此代码会返回以下错误:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least steps_per_epoch * epochs
.
我知道上面的代码只生成 1 个纪元的图像。为了解决这个问题,我尝试增加使用batches_per
参数生成的图像数量,使其等于时期数。这解决了这个问题,但是,当我尝试将 与 连接时train_ds
,train_ds_gen
我仍然遇到类似的问题,因为 . 的train_ds_gen
示例比 . 多 10 倍train_ds
。
我希望在每个时期都获得与真实图像一样多的生成图像。
如何设置生成器以生成足够的图像,同时拥有在生成图像和真实图像之间平衡的数据集。
解决方案
推荐阅读
- r - How to extract all words after the nth word from string in R?
- python - Python中map函数的行为
- android - 如何将所有按钮设置为无边框 Android
- python - Python 或 bash 脚本:如果模式在两个相同标记之间的行中,则删除行和第一个标记
- iphone - Swift 工具栏按钮仅在一个场景中触发 exec_bad_instruction
- android - Handling Entities and Pojos
- r - 从 R 函数生成 .md 降价文件
- go - Vault Token Helper not being detected?
- php - 未定义的偏移量#####
- android - 等待蓝牙启用