首页 > 解决方案 > 使用 tf.data.Dataset() 对映射进行矢量化时出错

问题描述

我有一个图像数据集,我通过tf.data.Dataset.list_files().

在我的.map()函数中,我读取和解码图像,如下所示:

def map_function(filepath):
    image = tf.io.read_file(filename=filepath)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
    return image

如果我使用(下面的工作)

 dataset = tf.data.Dataset.list_files(file_pattern=...)
 dataset = dataset.map(map_function)
 for image in dataset.as_numpy_iterator():
    #Correctly outputs the numpy array, no error is displayed/encountered
    print(image)

但是,如果我使用(下面会抛出错误):

  dataset = tf.data.Dataset.list_files(file_pattern=...)
  dataset = dataset.batch(32).map(map_function)
  for image in dataset.as_numpy_iterator():
    #Error is displayed 
      print(image)

ValueError:形状必须为 0 级,但对于具有输入形状的“ReadFile”(操作:“ReadFile”)为 1 级:[?]。

现在,根据这个:https://www.tensorflow.org/guide/data_performance#vectorizing_mapping,代码不应该失败并且预处理步骤应该被优化(批处理与一次性处理)。

我的代码中的错误在哪里?

***如果我使用map().batch()它工作正常

标签: pythontensorflowtensorflow2.0tensorflow-datasetstensorflow2.x

解决方案


发生错误是因为map_function需要未批处理的元素,但在第二个示例中,您为其提供了批处理元素。

https://www.tensorflow.org/guide/data_performance中的示例通过定义一个increment可以应用于批处理和非批处理元素的函数而变得棘手,因为将 1 添加到像 [1, 2, 3] 这样的批处理元素将导致在 [2, 3, 4] 中。

def increment(x):
    return x+1

要使用向量化,您需要编写 a vectorized_map_function,它接受一个未批处理元素的向量,将 map 函数应用于向量中的每个元素,然后返回结果向量。

不过,在您的情况下,我认为矢量化不会产生明显的影响,因为读取和解码文件的成本远高于调用函数的开销。当 map 函数非常便宜时,矢量化的影响最大,以至于函数调用所花费的时间与在 map 函数中实际工作所花费的时间相当。


推荐阅读