python - 了解 NVLabs/noise2noise 中的 tf.data 管道
问题描述
我正在尝试重新实现Nvidia的noise2noise repo的某些部分来学习张量流和tf.data
管道,但我在理解正在发生的事情时遇到了很多麻烦。到目前为止,我能够创建一个TFRecord
由https://github.com/NVlabs/noise2noise/blob/master/dataset_tool_tf.pytf.train.Example
中描述的类型组成的
image = load_image(imgname)
feature = {
'shape': shape_feature(image.shape),
'data': bytes_feature(tf.compat.as_bytes(image.tostring()))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
那部分是有道理的。让我发疯的是https://github.com/NVlabs/noise2noise/blob/master/dataset.py中的噪声增强部分,特别是函数:
def create_dataset(train_tfrecords, minibatch_size, add_noise):
print ('Setting up dataset source from', train_tfrecords)
buffer_mb = 256
num_threads = 2
dset = tf.data.TFRecordDataset(train_tfrecords, compression_type='', buffer_size=buffer_mb<<20)
dset = dset.repeat()
buf_size = 1000
dset = dset.prefetch(buf_size)
dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads)
dset = dset.shuffle(buffer_size=buf_size)
dset = dset.map(lambda x: random_crop_noised_clean(x, add_noise))
dset = dset.batch(minibatch_size)
it = dset.make_one_shot_iterator()
return it
返回一个迭代器。此迭代器用于train.py
并具有在每次迭代时返回的三个元素:
noisy_input, noisy_target, clean_target = dataset_iter.get_next()
我已经尝试在本地 tensorflow jupyter notebook 中重新实现它,但我无法弄清楚这三个项目的来源。按照我的理解,该create_dataset(...)
函数只获取Example
记录中的每个输入图像,并用高斯/泊松噪声对其进行增强。但是为什么返回的迭代器指向三个不同的图像呢?迭代器中的增强create_dataset(...)
和三个不同的图像之间有什么联系?
map
我发现了这一点,这对理解、batch
和非常有帮助shuffle
:批处理、重复和随机播放对 TensorFlow 数据集有什么作用?
解决方案
推荐阅读
- python - 使用 lamda 更改全局变量的值?
- node.js - AWS lambda 向电话号码发送 sns 消息
- python - 在某些条件下将一到三个之间的随机值添加到 DataFrame 的列中
- python - 如何使用 Selenium driver.find_elements_by_xpath 提取位置的名称。没有返回数据。试图打印出 Tit Heng Phone Shop
- c# - 当 signalR 连接丢失到集线器代理时,始终保持重新连接状态 2 分钟
- sql - 将行值转换为列名
- amazon-web-services - AWS API Gateway 不接收发布请求
- spring-boot - 在特定时间窗口中每 30 分钟运行一次的作业的 Cron 表达式
- shopify - 更改主题中的 shopify 标签
- java - 您应该在 `onBillingServiceDisconnected()` 中添加什么?