python - tf.string_input_producer 即使在将 epoch 设置为大于 1 后也会给出单个 epoch
问题描述
我的输入 api 使用 tf.string_input_producer 和 tf.parse_single_sequence_example。当我在 tf.string_input_producer 中设置 num_epochs > 1 时,我的队列仍然会在一个 epoch 之后完成。
这是预期的行为还是我做错了什么?以下是相关代码:
class TFRecordReader():
def __init__(self):
#some code....
def execute_queue(self, tensor_queue, exception_message: str, log_dir_path=None):
import os
if log_dir_path is None:
path = os.path.abspath('../../audio_log_dir/')
else:
path = log_dir_path
writer = self._summary_file_writer(path)
coord, thread = self._coord_thread()
print('should_stop: ', coord.should_stop())
if not coord.should_stop():
try:
if self._data_v is None:
self._data_v = self._parse_tensor(tensor_queue)
return self._data_v
except self.tf.errors.OutOfRangeError:
print(exception_message)
finally:
coord.request_stop()
coord.join(thread)
writer.close()
def single_sequence_batch(self,
tf_record_path,
feature_map,
parse_function,
num_epochs=None,
tf_record_compression=None,
queue_completion_message='Data Exhausted!',
log_dir_path=None
):
self.feature_map = feature_map
self.parse_func = parse_function
batch = self._single_sequence_batch(tf_record_path=tf_record_path,
num_epochs=num_epochs,
tf_record_compression=tf_record_compression)
data_queue = self.execute_queue(batch, queue_completion_message, log_dir_path=log_dir_path)
return data_queue
def _test_single_sequence_batch(num_epochs=1):
tfr_path = r'C:/audio_tfrecord/audioapi.tfrecord'
reader = TFRecordReader()
data_q = reader.single_sequence_batch(tf_record_path=tfr_path,
feature_map=feature_mapping,
parse_function=parse_func,
tf_record_compression=True,
num_epochs=num_epochs)
print(len(data_q))
c = 0
try:
for i in range(num_epochs):
val = reader.session.run(data_q)
print(val)
c += 1
except tf.errors.OutOfRangeError:
print("Total Examples :", c)
print('Finished!')
解决方案
推荐阅读
- google-cloud-platform - gcloud auth activate-service-account 注销/撤销/删除/取消设置
- javascript - 如何将输入滚动到视图中但顶部有一些边距
- javascript - 从 JavaScript 访问 Java Servlet
- r - R Shiny:varSelectInput 仅来自数据帧的数字变量
- graphql - Appsync 如何将请求标头传递给子解析器?
- c# - 在运行时动态创建按钮时,使用 Interaction.Triggers 将按钮单击事件绑定到 ViewModel 中的方法
- sql - 如何在oracle中为每个员工逐列比较两个员工表
- cuda - 关于 Cuda 1D 卷积,我怎样才能更快地做到这一点?
- r - 找出一个参数的平均值是多少个标准偏差从 0 R
- python - 事务请求不能包含对一项python的多个操作