首页 > 解决方案 > 无法从 tf.data.Dataset.from_generator 读取数据

问题描述

我希望使用 tf.data.Dataset.from_generator 来帮助我生成输入数据流。

dataset = tf.data.Dataset.from_generator(sample_generator,
                                         output_types=(tf.float32, tf.float32, tf.float32), 
                                         output_shapes=(tf.TensorShape([size_image, size_image, 3]),tf.TensorShape([size_image,size_image,3]), tf.TensorShape([size_gt, size_gt])))

sample_generator 函数可以生成三个具有建议形状的 numpy 数组。

上面提到的部分没有问题。但是,tf.data.Dataset.from_generator 只能生成数据流,我需要将生成的数据提供给网络。

代码如下:

dataset = dataset.map(transform_fn, num_parallel_calls=self.config['prefetch_threads']) # transform_fn just returns the input

dataset = dataset.prefetch(self.config['prefetch_capacity'])
dataset = dataset.repeat()
dataset = dataset.batch(self.config['batch_size'])

迭代器是

self.iterator = self.dataset_tf.make_one_shot_iterator()

self.iterator.get_next()

谢谢!

标签: pythontensorflowdeep-learning

解决方案


问题已解决。更改 from_generator 以生成元数据,例如,您需要的文件的名称/路径。

然后,使用 tf.Py_func 以普通/numpy 方式预处理数据。我只想说:“tf.Pyfunc 太方便了!” 谢谢!


推荐阅读