首页 > 解决方案 > 了解 NVLabs/noise2noise 中的 tf.data 管道

问题描述

我正在尝试重新实现Nvidia的noise2noise repo的某些部分来学习张量流和tf.data管道,但我在理解正在发生的事情时遇到了很多麻烦。到目前为止,我能够创建一个TFRecordhttps://github.com/NVlabs/noise2noise/blob/master/dataset_tool_tf.pytf.train.Example中描述的类型组成的

 image = load_image(imgname)
    feature = {
      'shape': shape_feature(image.shape),
      'data': bytes_feature(tf.compat.as_bytes(image.tostring()))
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

那部分是有道理的。让我发疯的是https://github.com/NVlabs/noise2noise/blob/master/dataset.py中的噪声增强部分,特别是函数:

def create_dataset(train_tfrecords, minibatch_size, add_noise):
    print ('Setting up dataset source from', train_tfrecords)
    buffer_mb   = 256
    num_threads = 2
    dset = tf.data.TFRecordDataset(train_tfrecords, compression_type='', buffer_size=buffer_mb<<20)
    dset = dset.repeat()
    buf_size = 1000
    dset = dset.prefetch(buf_size)
    dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads)
    dset = dset.shuffle(buffer_size=buf_size)
    dset = dset.map(lambda x: random_crop_noised_clean(x, add_noise))
    dset = dset.batch(minibatch_size)
    it = dset.make_one_shot_iterator()
    return it

返回一个迭代器。此迭代器用于train.py并具有在每次迭代时返回的三个元素:

    noisy_input, noisy_target, clean_target = dataset_iter.get_next()

我已经尝试在本地 tensorflow jupyter notebook 中重新实现它,但我无法弄清楚这三个项目的来源。按照我的理解,该create_dataset(...)函数只获取Example记录中的每个输入图像,并用高斯/泊松噪声对其进行增强。但是为什么返回的迭代器指向三个不同的图像呢?迭代器中的增强create_dataset(...)和三个不同的图像之间有什么联系?

map我发现了这一点,这对理解、batch和非常有帮助shuffle批处理、重复和随机播放对 TensorFlow 数据集有什么作用?

标签: pythontensorflow

解决方案


推荐阅读