首页 > 解决方案 > 将 TFRecords 与 Tensorflow Estimators 一起使用时,是否有一种简单的方法来设置时期

问题描述

将 numpy 数组输入估计器时,有一种很好的方法来设置 epoch

  tf.estimator.inputs.numpy_input_fn(
     x,
     y=None,
     batch_size=128,
     num_epochs=1 ,
     shuffle=None,
     queue_capacity=1000,
     num_threads=1  
   )

但是我无法使用 TFRecords 找到类似的方法,大多数人似乎只是将它粘在一个循环中

 i = 0 
 while ( i < 100000):
   model.train(input_fn=input_fn, steps=100)

有没有一种干净的方法来显式设置带有估计器的 TFRecords 的时期数?

标签: tensorflowtfrecord

解决方案


您可以使用 设置纪元数dataset.repeat(num_epochs)。数据集管道输出一个数据集对象,一个批量大小的元组(特征,标签),输入到model.train()

dataset = tf.data.TFRecordDataset(file.tfrecords)
dataset = tf.shuffle().repeat()
...
dataset = dataset.batch()

为了使其工作,您设置model.train(steps=None, max_steps=None)在这种情况下,您让 Dataset API 通过在达到 num_epoch 时生成tf.errors.OutOfRange错误或异常来处理 epochs 计数。StopIteration


推荐阅读