首页 > 解决方案 > 使用 TensorFlow 训练 CNN 时如何修复“OutOfRangeError:序列结束”错误?

问题描述

我正在尝试使用我自己的数据集训练 CNN。我一直在使用 tfrecord 文件和 tf.data.TFRecordDataset API 来处理我的数据集。它适用于我的训练数据集。但是当我尝试批处理我的验证数据集时,出现了“OutOfRangeError:序列结束”的错误。上网浏览后,我认为问题是验证集的批大小引起的,我一开始设置为 32。但是在我将其更改为 2 之后,代码运行了 9 个 epoch,并且再次引发了错误。

我使用输入函数来处理数据集,代码如下:

def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_one_shot_iterator()

    features, labels = iterator.get_next()

    return features, labels

对于训练集,“batch_size”设置为 128,“num_epochs”设置为 None,这意味着无限重复。对于验证集,“batch_size”设置为 32(后来设置为 2,仍然无效),“num_epochs”设置为 1,因为我只想通过验证集一次。我可以保证验证集包含足够的时代数据。因为我已经尝试了下面的代码并且它没有引发任何错误:

with tf.Session() as sess:
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    for i in range(450):
        sess.run([features, labels])
        print(labels.shape)

在上面的代码中,当我将数字 450 更改为 500 或更大时,它会引发“OutOfRangeError”。这可以确认我的验证数据集包含足够的数据,可用于 450 次迭代,批量大小为 32。

我尝试对验证集使用较小的批量大小(即 2),但仍然有相同的错误。我可以在 input_fn 中将“num_epochs”设置为“None”的情况下运行代码以进行验证,但这似乎不是验证的工作方式。请问有什么帮助吗?

标签: python-3.xtensorflowtensorflow-datasets

解决方案


这种行为是正常的。来自 Tensorflow 文档:

如果迭代器到达数据集的末尾,则执行Iterator.get_next()操作将引发tf.errors.OutOfRangeError. 此后,迭代器将处于不可用状态,如果您想进一步使用它,则必须再次对其进行初始化。

设置时没有引发错误的原因dataset.repeat(None)是因为数据集永远不会耗尽,因为它会无限重复。

要解决您的问题,您应该将代码更改为:

n_steps = 450
...    

with tf.Session() as sess:
    # Training
    features, labels = input_fn(True, training_list, 32, 1, 1)

    for step in range(n_steps):
        sess.run([features, labels])
        ...
    ...
    # Validation
    features, labels = input_fn(False, valid_list, 32, 1, 1)
    try:
        sess.run([features, labels])
        ...
    except tf.errors.OutOfRangeError:
        print("End of dataset")  # ==> "End of dataset"

您还可以对 input_fn 进行一些更改以在每个时期运行评估:

def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
    dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
    if is_training:
        dataset = dataset.shuffle(buffer_size=1500)
    dataset = dataset.map(parse_record)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_initializable_iterator()
    return iterator

n_epochs = 10
freq_eval = 1

training_iterator = input_fn(True, training_list, 32, 1, 1)
training_features, training_labels = training_iterator.get_next()

val_iterator = input_fn(False, valid_list, 32, 1, 1)
val_features, val_labels = val_iterator.get_next()

with tf.Session() as sess:
    # Training
    sess.run(training_iterator.initializer)
    for epoch in range(n_epochs):
        try:
            sess.run([training_features, training_labels])
        except tf.errors.OutOfRangeError:
            pass

        # Validation
        if (epoch+1) % freq_eval == 0:
            sess.run(val_iterator.initializer)
            try:
                sess.run([val_features, val_labels])
            except tf.errors.OutOfRangeError:
                pass

如果您想更好地了解幕后发生的事情,我建议您仔细查看此官方指南。


推荐阅读