python - 来自生成器的数据集,一次产生多个元素
问题描述
如果是时候从已弃用的基于队列的 API 迁移到 TensorFlow 中的数据集 API,我正在试水。
我似乎找不到等效的一个用例enqueue_many
是tf.train.batch
.
特别是我想创建一个可以产生“批处理”数组的 Python 生成器,其中“批处理大小”不一定与用于 SGD 训练更新的那个相同,然后对该数据流应用批处理(即与 tf.train.batch 中的 enqueue_many 一起使用)。
在新的数据集 API 中是否有任何解决方法来实现这一点?
解决方案
Try using flatmap
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
n_reads=10
read_batch_size=20
training_batch_size = 2
def mnist_gen():
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
for i in range(n_reads):
batch_x, batch_y = mnist.train.next_batch(read_batch_size)
# Yielding a batch instead of single record
yield batch_x,batch_y
data = tf.data.Dataset.from_generator(mnist_gen,output_types=(tf.float32,tf.float32))
data = data.flat_map(lambda *x: tf.data.Dataset.zip(tuple(map(tf.data.Dataset.from_tensor_slices,x)))).batch(training_batch_size)
# if u yield only batch_x change lambda function to data.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)))
iter = data.make_one_shot_iterator()
next_item = iter.get_next()
X= next_item[0]
Y = next_item[1]
with tf.Session() as sess:
for i in range(n_reads*read_batch_size // training_batch_size):
print(i, sess.run(X))
推荐阅读
- python - (再次)在 Windows 上访问 Python 中的长路径
- javascript - 如何在地图中为我的密钥道具生成唯一密钥?反应
- laravel - Laravel,自动生成新链接
- rust - 我可以运行夜间和稳定的编译器吗?
- firebase - 如何阻止从控制台更改 Firebase 实时数据库
- sharepoint - 以“归档格式”从 OneNote 文档中提取文本
- c++ - 无效指针的条件
- python - 如何从索引中的时间戳获取一天的秒数?
- javascript - 如何将 MongoDB ObjectID 作为十六进制字符串获取到渲染器中
- javascript - formControl 的值未显示在视图中