tensorflow2.0 - 您输入的数据用完了
问题描述
history = model.fit_generator(
train_generator,
steps_per_epoch=50,
epochs=10,
verbose=1,
validation_data = validation_generator,
validation_steps=50)
tensorflow:你的输入数据用完了;中断训练。确保您的数据集或生成器至少可以生成steps_per_epoch * epochs
批次(在本例中为 5000 个批次)。在构建数据集时,您可能需要使用 repeat() 函数。
解决方案
要解决这个问题,我们需要注意两点:
- 如何定义批大小和batch_size 和steps_per_epoch。简单的答案是steps_per_epoch=total_train_size//batch_size
- 如何定义训练过程的最大时期数。这不像第一个那么简单。
大多数答案都涵盖了第一个主题,我没有找到第二个很好的答案,尝试解释如下:
我有一个包含 93 个数据样本的训练数据集,batch_size 为 32,所以对于第一个问题:steps_per_epoch=total_train_size//batch_size=93//32=2
对于第二个问题,这将取决于您的数据生成器可以提供多少不重复的批次,如果我有 93 个数据样本并且每个批次需要 32 个两个样本,那么每个时期都有 2 个训练步骤。您将有 93//2 = 46 个时期能够提供不重复的批次,时期 47 将导致此错误。
我没有找到 tensorflow 数据生成器的参考,所以这只是我的理解,如果有任何错误请纠正我,谢谢!