首页 > 解决方案 > 如何使用提供的代码 Tensorflow 在迭代期间仅使用一半的 imagenet 训练集?

问题描述

代码由 Tensorflow 提供,这是在 trianing 时获取 ImageNet TFrecord 文件的方法:

import tensorflow as tf       
import imagenet_data
import image_processing

imagenet_data_train = imagenet_data.ImagenetData('train')
train_images, train_labels =  image_processing.inputs(imagenet_data_train, batch_size=256, num_preprocess_threads=16)

coord = tf.train.Coordinator()
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
   threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))


with tf.Session() as sess:      

    sess.run(tf.global_variables_initializer())

    try:
        for i in range(1000):    
            image_batch, label_batch = sess.run([train_images, train_labels ])

    finally:
            coord.request_stop()
            coord.join(threads)

而现在我只想使用一半的训练数据(可能是 Tfreocd 文件中的前 60 万条数据)在训练期间进行迭代,我应该设置什么?

标签: pythontensorflowimagenet

解决方案


我通过修改类 ImagenetDatatrain_set_number_rate并添加一个额外的参数train_set_number_rate并修改方法来解决它,该方法data_files控制要传递给下一个函数的文件名列表。

请注意,通过这种方式,我必须使用函数distorted_inputs而不是inputs确保在不同模型中使用训练集的相同部分进行微调。(这可能会对 vali 性能造成一些不良影响。但由于我曾经inputs训练过原始模型网络,inputs必须用于确保比较的正确性。)

 class ImagenetData(Dataset):

     def __init__(self, subset, train_set_number_rate = None):
        super(ImagenetData, self).__init__('ImageNet', subset)
        self.train_set_number_rate = train_set_number_rate
     def data_files(self):
        """Returns a python list of all (sharded) data subset files.

        Returns:
          python list of all (sharded) data set files.
        Raises:
          ValueError: if there are not data_files matching the subset.
        """
        tf_record_pattern = os.path.join(data_dir, '%s-*' % self.subset)
        data_files = tf.gfile.Glob(tf_record_pattern)
        if self.subset=='validation':
            assert(self.train_set_number_rate==None)
        elif self.subset=='train':
            if self.train_set_number_rate!=None:
                data_files = data_files[0:round(len(data_files)*self.train_set_number_rate)]
        if not data_files:
            print('No files found for dataset %s/%s at %s' % (self.name,
                                                              self.subset,
                                                              data_dir))

            self.download_message()
            exit(-1)
        return data_files

推荐阅读