python-3.x - 使用 TensorFlow 训练 CNN 时如何修复“OutOfRangeError:序列结束”错误?
问题描述
我正在尝试使用我自己的数据集训练 CNN。我一直在使用 tfrecord 文件和 tf.data.TFRecordDataset API 来处理我的数据集。它适用于我的训练数据集。但是当我尝试批处理我的验证数据集时,出现了“OutOfRangeError:序列结束”的错误。上网浏览后,我认为问题是验证集的批大小引起的,我一开始设置为 32。但是在我将其更改为 2 之后,代码运行了 9 个 epoch,并且再次引发了错误。
我使用输入函数来处理数据集,代码如下:
def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
if is_training:
dataset = dataset.shuffle(buffer_size=1500)
dataset = dataset.map(parse_record)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
对于训练集,“batch_size”设置为 128,“num_epochs”设置为 None,这意味着无限重复。对于验证集,“batch_size”设置为 32(后来设置为 2,仍然无效),“num_epochs”设置为 1,因为我只想通过验证集一次。我可以保证验证集包含足够的时代数据。因为我已经尝试了下面的代码并且它没有引发任何错误:
with tf.Session() as sess:
features, labels = input_fn(False, valid_list, 32, 1, 1)
for i in range(450):
sess.run([features, labels])
print(labels.shape)
在上面的代码中,当我将数字 450 更改为 500 或更大时,它会引发“OutOfRangeError”。这可以确认我的验证数据集包含足够的数据,可用于 450 次迭代,批量大小为 32。
我尝试对验证集使用较小的批量大小(即 2),但仍然有相同的错误。我可以在 input_fn 中将“num_epochs”设置为“None”的情况下运行代码以进行验证,但这似乎不是验证的工作方式。请问有什么帮助吗?
解决方案
这种行为是正常的。来自 Tensorflow 文档:
如果迭代器到达数据集的末尾,则执行
Iterator.get_next()
操作将引发tf.errors.OutOfRangeError
. 此后,迭代器将处于不可用状态,如果您想进一步使用它,则必须再次对其进行初始化。
设置时没有引发错误的原因dataset.repeat(None)
是因为数据集永远不会耗尽,因为它会无限重复。
要解决您的问题,您应该将代码更改为:
n_steps = 450
...
with tf.Session() as sess:
# Training
features, labels = input_fn(True, training_list, 32, 1, 1)
for step in range(n_steps):
sess.run([features, labels])
...
...
# Validation
features, labels = input_fn(False, valid_list, 32, 1, 1)
try:
sess.run([features, labels])
...
except tf.errors.OutOfRangeError:
print("End of dataset") # ==> "End of dataset"
您还可以对 input_fn 进行一些更改以在每个时期运行评估:
def input_fn(is_training, filenames, batch_size, num_epochs=1, num_parallel_reads=1):
dataset = tf.data.TFRecordDataset(filenames,num_parallel_reads=num_parallel_reads)
if is_training:
dataset = dataset.shuffle(buffer_size=1500)
dataset = dataset.map(parse_record)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_initializable_iterator()
return iterator
n_epochs = 10
freq_eval = 1
training_iterator = input_fn(True, training_list, 32, 1, 1)
training_features, training_labels = training_iterator.get_next()
val_iterator = input_fn(False, valid_list, 32, 1, 1)
val_features, val_labels = val_iterator.get_next()
with tf.Session() as sess:
# Training
sess.run(training_iterator.initializer)
for epoch in range(n_epochs):
try:
sess.run([training_features, training_labels])
except tf.errors.OutOfRangeError:
pass
# Validation
if (epoch+1) % freq_eval == 0:
sess.run(val_iterator.initializer)
try:
sess.run([val_features, val_labels])
except tf.errors.OutOfRangeError:
pass
如果您想更好地了解幕后发生的事情,我建议您仔细查看此官方指南。
推荐阅读
- c# - JObject.Parse 上的额外对象包装器
- php - 在嵌套数组中循环键值以在 PHP 中发出 SOAP 请求
- c++ - 生成带有顶点和索引的球体?
- docker - AWS ECS(EC2启动类型),Nest js中的docker容器TCP通信
- c++ - 在类中初始化 unique_ptr
- jupyter-notebook - 如何使用此 python 代码解决关键错误?
- amazon-web-services - 如果我不添加默认变量,则 Terraform/Terragrunt 错误
- python - Django 3:无法生成动态对象视图
- html - 动画旋转框 css 关键帧
- node.js - 如何在 MongoDB 中相互比较文档?