首页 > 解决方案 > Tensorflow 将数据集拆分为训练和测试导致瓶颈/缓慢

问题描述

我有一个数据集,当我用ds = ds.map(process_path, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)那条线对其进行预处理时,它的执行速度非常快。然后,当我尝试使用以下方法访问数据集的元素之一时:

for image, label in ds.take(1):
  print(image.shape)
  image = tf.squeeze(image)
  plt.imshow(image, cmap='gray')

加载需要一两秒钟;这是我的第一个问题:

预处理是否仅在访问数据集中的元素时才在数据集上运行,而不是在我调用 ds.map(process_path,...) 时立即运行?

然而,我的主要问题是,当我将数据集ds分为训练和测试两部分并尝试再次访问其中一个元素时,速度相当慢......就像慢了 20 倍。我把它分成两部分:

test_ds_size = int(image_count * 0.2)
train_ds = ds.skip(test_ds_size)
test_ds = ds.take(test_ds_size)

然后我尝试以与上述相同的方式访问它,但替换dstrain_ds; 我的第二个问题是:

为什么这要慢得多,只是将它分成两部分?

还是我做错了什么...

任何帮助是极大的赞赏。

标签: pythontensorflowmachine-learningtensorflow2.0

解决方案


dataset.map通过应用 map 函数创建一个新的数据集。
即使在循环中,当您执行dataset.take()它时,它也会以非常短的时间从指定的数字创建一个新数据集。
加载数据集后,您正在执行与性能无关的其他操作tf.data
您可以从以下示例中进行检查。

import tensorflow as tf
from time import time

dataset = tf.data.Dataset.range(1, 100)
t1 = time()
dataset = dataset.map(lambda x: x + 1)
t2 = time()
print("Time taken for map : ", t2-t1)

t3 = time()
ds = dataset.take(50)
t4 = time()
list(ds.as_numpy_iterator())
print("Time taken for take() : ",t4-t3) 

Time taken for map :  0.013489961624145508
Time taken for take() :  0.0005645751953125

现在,让我们看看 take() 在一些操作之后所花费的时间。

dataset = tf.data.Dataset.range(1, 100)
t1 = time()
dataset = dataset.map(lambda x: x + 1)
t2 = time()
print("Time taken for map : ", t2-t1)

t3 = time()
ds = dataset.take(50)
list(ds.as_numpy_iterator())
t4 = time()
print("Time taken for take() after some operation : ",t4-t3)

Time taken for map :  0.00974416732788086
Time taken for take() after some operation :  0.017722606658935547 

可以按照您指定的方式从现有数据集中拆分训练数据和测试数据,但这需要时间,因为它会遍历所有元素。

tf.data.Dataset为训练和测试创建的理想方法是分别创建它,如下所示。确保在训练和测试数据中正确分布数据集之前对数据进行混洗。


推荐阅读