首页 > 解决方案 > Tensorflow 数据集 API 获取大批量数据的速度很慢

问题描述

我发现即使所有数据都在内存中,当批量很大时,从 tensorflow 数据集 API 获取一批可能会非常慢。以下是一个例子。有没有人有任何见识?

FEATURE_NUM = 500
tf_X = tf.placeholder(dtype=tf.float32, shape=[None, FEATURE_NUM], name="X")
tf_Y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="Y")

batch_size = 1000000
dataset = tf.data.Dataset.from_tensor_slices((tf_X, tf_Y)).batch(batch_size)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

se = tf.Session()
se.run(tf.global_variables_initializer())
se.run(iterator.initializer, feed_dict={tf_X : numpy_array_X, tf_Y : numpy_array_Y})

while True:
    data = se.run(next_element) # This takes more than 5 seconds per call

标签: performancetensorflow

解决方案


推荐阅读