首页 > 解决方案 > 您输入的数据用完了

问题描述

history = model.fit_generator(
      train_generator,
      steps_per_epoch=50,  
      epochs=10,
      verbose=1,
      validation_data = validation_generator,
      validation_steps=50)

tensorflow:你的输入数据用完了;中断训练。确保您的数据集或生成器至少可以生成steps_per_epoch * epochs批次(在本例中为 5000 个批次)。在构建数据集时,您可能需要使用 repeat() 函数。

标签: tensorflow2.0tensorflow-datasets

解决方案


要解决这个问题,我们需要注意两点:

  1. 如何定义批大小和batch_size 和steps_per_epoch。简单的答案是steps_per_epoch=total_train_size//batch_size
  2. 如何定义训练过程的最大时期数。这不像第一个那么简单。

大多数答案都涵盖了第一个主题,我没有找到第二个很好的答案,尝试解释如下:

我有一个包含 93 个数据样本的训练数据集,batch_size 为 32,所以对于第一个问题:steps_per_epoch=total_train_size//batch_size=93//32=2

对于第二个问题,这将取决于您的数据生成器可以提供多少不重复的批次,如果我有 93 个数据样本并且每个批次需要 32 个两个样本,那么每个时期都有 2 个训练步骤。您将有 93//2 = 46 个时期能够提供不重复的批次,时期 47 将导致此错误。

我没有找到 tensorflow 数据生成器的参考,所以这只是我的理解,如果有任何错误请纠正我,谢谢!


推荐阅读