首页 > 解决方案 > 使用 Tensorflow 进行训练时如何正确实施验证?

问题描述

那里。我想在训练过程中验证我的神经网络。创建一个占位符来决定是否应该将 train_batch 或 test_batch 输入网络以实现验证。这是我的示例代码。在train.tfrecords中,数字从 0 到 100,在test.tfrecords中,从 200 到 240。

create_tfrecord(filename=r'test', begin_num=200, end_num=240)

filename = r'./train.tfrecords'
filename_que = tf.train.string_input_producer([filename],
                                              num_epochs=1)
num = decode_tfrecord(filename_que)

testfilename = r'./test.tfrecords'
test_que = tf.train.string_input_producer([testfilename],
                                          num_epochs=None)
test_num = decode_tfrecord(test_que)

is_training = tf.placeholder(dtype=tf.bool, shape=())

num = tf.cond(is_training, lambda:num, lambda:test_num)    

coord = tf.train.Coordinator()

init_op = tf.local_variables_initializer()
with tf.Session() as sess:

    sess.run(init_op)
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    try:
        i = 0
        while not coord.should_stop():

            feed_dict = {is_training: True}
            number = sess.run(num, feed_dict=feed_dict)
            print("global_step: %i, read_num: %i" %(i, number))

            if i % 5 == 0 and i != 0:
                feed_dict = {is_training: False}
                number = sess.run(num, feed_dict=feed_dict)
                print("global_step: %i, read_num: %i" %(i, number))

            i += 1


    except tf.errors.OutOfRangeError:
        print("done.")

    finally:
        coord.request_stop()
        coord.join(threads)

但是,输出与我想要的并不完全相同。test_batch 不是从 test.tfrecords 中的第一个数字 200 开始,而是从 206 开始。之后, train_batch从 7 而不是 6 开始。这是一些输出。

global_step: 0, read_num: 0
global_step: 1, read_num: 1
global_step: 2, read_num: 2
global_step: 3, read_num: 3
global_step: 4, read_num: 4
global_step: 5, read_num: 5
global_step: 5, read_num: 206
global_step: 6, read_num: 7
global_step: 7, read_num: 8
global_step: 8, read_num: 9
global_step: 9, read_num: 10
global_step: 10, read_num: 11
global_step: 10, read_num: 212

这些输出的行为就像存在一个用于训练和验证的计数器。一旦 epochs 设置为 1,并不是我所有的训练集都可以在训练阶段使用,并且我的 test_batch 获取的顺序不是我想要的。我认为这种方式是在训练时实现验证的错误方式。

那么,有没有人可以告诉我如何在训练我的网络期间正确实施验证?

提前致谢。

标签: pythontensorflow

解决方案


推荐阅读