python - 在受监控的训练会话中数据集迭代器初始化的 TensorFlow 错误
问题描述
大家好,我需要一些帮助。我尝试在不使用估计器的情况下使用 tensorflow 对 resnet-101 imagenet 分类进行编码。我尝试用它来研究深度学习并了解如何使用 tensorflow。我的问题是 monitortrainingSession 没有初始化我的迭代器。
我已经阅读了一些关于这些问题的文章并尝试使用钩子来处理它,但它失败了,我不知道它为什么会失败。
在我创建 monitortrainingsession 之后,它首先初始化 train_iterator 并得到 outOfRange 异常,然后执行验证步骤。
现在看起来很好,但在完成运行验证并尝试再次运行训练步骤后。我得到了与 iterator.get_next() 相关的错误。它说我没有初始化迭代器,但我的钩子函数清楚地调用了
session.run(self._initializer, feed_dict={filenames: self._filenames})
我确定,因为我可以看到我打印的以下消息以检查它是否已初始化。
iter_val.initializer after_create_session is called 0 times
我有什么问题?
运行流程如下
运行训练步骤很好(epoch = 0)
运行验证步骤正常(epoch = 0)
运行训练步骤错误(epoch = 1)
请忽略代码中的 horovod(hvd()) ,因为我现在没有使用它。
这是我的代码,所以请帮我修复它,让我知道我的代码有什么问题。
class _DatasetInitializerHook(tf.train.SessionRunHook):
def __init__(self, initializer, filenames=[], name=""):
self._initializer = initializer
self._filenames = filenames
self._name = name
self._cnt = 0
self._before_runCnt = 0
def begin(self):
pass
def after_create_session(self, session, coord):
del coord
if len(self._filenames) == 0:
session.run(self._initializer)
else:
session.run(self._initializer, feed_dict={filenames: self._filenames})
print(self._name, "after_create_session is called {} times".format(self._cnt))
self._cnt += 1
if __name__ == "__main__":
if len(sys.argv) > 1:
nlogs = sys.argv[1]
else:
nlogs = 0
hvd.init()
b_imagenet=False
if b_imagenet:
training_filenames = ['/data/tfrecords/imagenet2012_train_shard{}.tfrecord'.format(i) for i in range(129)]
else:
training_filenames = ['/data/cifar-10-tfrecords/train_shard{}.tfrecord'.format(i) for i in range(1, 2, 1)]
filenames = tf.placeholder(tf.string, shape=[None])
trainData = dataset_input_fn(is_training=True, filename=filenames, nworkers=hvd.size(), workeridx=hvd.rank(),
batch_size=FLAGS.batchSize, prefetch_size=FLAGS.prefetch_buffer_size, repeat=1,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)
valData = dataset_input_fn(is_training=False, filename=FLAGS.validationfile, nworkers=hvd.size(), workeridx=hvd.rank(),
batch_size=1,prefetch_size=FLAGS.prefetch_buffer_size, repeat=1, shuffle_buffer_size=1)
# Pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
for i in tqdm(range(FLAGS.nepoch)):
shuffle(training_filenames)
model = model_class(nCls=FLAGS.nClasses, img_width=FLAGS.width, img_height=FLAGS.height,
learning_rate=FLAGS.learning_rate, weight_decay=FLAGS.weight_decay)
iter_train = trainData.make_initializable_iterator()
train_op = model.build_model(iter_train.get_next(), is_trainig=True, hvd=None)
train_hooks = [hvd.BroadcastGlobalVariablesHook(0),
_DatasetInitializerHook(iter_train.initializer, training_filenames, "iter_train.initializer")]
with tf.train.MonitoredTrainingSession(checkpoint_dir="./tmp/train_logs", config=config, hooks=train_hooks,
save_checkpoint_secs=30) as sess:
try:
while True:
opt = sess.run([train_op])
except tf.errors.OutOfRangeError:
pass
iter_val = valData.make_initializable_iterator()
prediction_result = model.build_model(iter_val.get_next(),is_trainig=False, hvd=None)
validation_hooks = [hvd.BroadcastGlobalVariablesHook(0),
_DatasetInitializerHook(iter_val.initializer, [], "iter_val.initializer")]
with tf.train.MonitoredTrainingSession( checkpoint_dir="./tmp/train_logs",config=config, hooks=validation_hooks) as sess:
try:
while True:
result = sess.run([prediction_result])
except tf.errors.OutOfRangeError:
pass
这是我收到的错误消息。
tensorflow.python.framework.errors_impl.FailedPreconditionError: GetNext() 失败,因为迭代器尚未初始化。确保在获取下一个元素之前已为此迭代器运行了初始化程序操作。[[节点IteratorGetNext(定义在workspace/multi_gpu/main.py:128)]]
错误可能源于输入操作。连接到节点IteratorGetNext的输入源操作:IteratorV2_2(定义在workspace/multi_gpu/main.py:126)
解决方案
尝试将初始化程序放入脚手架:
scaffold = tf.train.Scaffold(local_init_op=train_init_operator)
并将其交给monitoredTrainingSession
with:
with tf.train.MonitoredTrainingSession(scaffold=scaffold, ...
推荐阅读
- python - 无法相应地打印结果并将相同的结果写入 csv 文件
- c++ - 玩转 C++20 概念
- pandas - 矢量化熊猫申请 pd.date_range
- python - 如何安装最新版本的 TensoFlow 2?
- vue.js - vue + nuxt js - 如何在服务器端的插件中访问上下文?
- javascript - Javascript比较两个对象键生成新对象
- git - 使用 Ansible 的 git pull 命令
- php - 在 php 中通过邮件发送带有数据的 HTML 表
- python - 如何键入变量并将其读取为普通文本
- go - 你如何在 JS catch 中从 Go 服务器进行 fetch 调用?