python - 如何使用提供的代码 Tensorflow 在迭代期间仅使用一半的 imagenet 训练集?
问题描述
代码由 Tensorflow 提供,这是在 trianing 时获取 ImageNet TFrecord 文件的方法:
import tensorflow as tf
import imagenet_data
import image_processing
imagenet_data_train = imagenet_data.ImagenetData('train')
train_images, train_labels = image_processing.inputs(imagenet_data_train, batch_size=256, num_preprocess_threads=16)
coord = tf.train.Coordinator()
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
for i in range(1000):
image_batch, label_batch = sess.run([train_images, train_labels ])
finally:
coord.request_stop()
coord.join(threads)
而现在我只想使用一半的训练数据(可能是 Tfreocd 文件中的前 60 万条数据)在训练期间进行迭代,我应该设置什么?
解决方案
我通过修改类 ImagenetDatatrain_set_number_rate
并添加一个额外的参数train_set_number_rate
并修改方法来解决它,该方法data_files
控制要传递给下一个函数的文件名列表。
请注意,通过这种方式,我必须使用函数distorted_inputs
而不是inputs
确保在不同模型中使用训练集的相同部分进行微调。(这可能会对 vali 性能造成一些不良影响。但由于我曾经inputs
训练过原始模型网络,inputs
必须用于确保比较的正确性。)
class ImagenetData(Dataset):
def __init__(self, subset, train_set_number_rate = None):
super(ImagenetData, self).__init__('ImageNet', subset)
self.train_set_number_rate = train_set_number_rate
def data_files(self):
"""Returns a python list of all (sharded) data subset files.
Returns:
python list of all (sharded) data set files.
Raises:
ValueError: if there are not data_files matching the subset.
"""
tf_record_pattern = os.path.join(data_dir, '%s-*' % self.subset)
data_files = tf.gfile.Glob(tf_record_pattern)
if self.subset=='validation':
assert(self.train_set_number_rate==None)
elif self.subset=='train':
if self.train_set_number_rate!=None:
data_files = data_files[0:round(len(data_files)*self.train_set_number_rate)]
if not data_files:
print('No files found for dataset %s/%s at %s' % (self.name,
self.subset,
data_dir))
self.download_message()
exit(-1)
return data_files
推荐阅读
- gtsummary - 在 gtsummary 中显示和比较连续变量的正态分布与非正态分布的简单方法
- azure - Power shell 命令禁用 Azure 数据工厂的网络设置
- angular - 关闭模式或背景点击后如何防止页面滚动到顶部?
- python - 如何检测时间序列中的翻转?
- css - 我可以更改css中元素的父级吗?
- php - 不允许在购物车中添加特定的产品数量
- hash - 哈希表函数 hlist_add_before 在 Linxu 内核中的实现
- microsoft-graph-api - 将 Graph API 端点轮询到 Teams 状态信息
- python - 将 mako 渲染模板写入文件时创建的附加回车
- javascript - 如何根据属性值比较和过滤数组对象?