python - 使用 tf.data.Dataset() 对映射进行矢量化时出错
问题描述
我有一个图像数据集,我通过tf.data.Dataset.list_files()
.
在我的.map()
函数中,我读取和解码图像,如下所示:
def map_function(filepath):
image = tf.io.read_file(filename=filepath)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, [IMAGE_WIDTH, IMAGE_HEIGHT])
return image
如果我使用(下面的工作)
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.map(map_function)
for image in dataset.as_numpy_iterator():
#Correctly outputs the numpy array, no error is displayed/encountered
print(image)
但是,如果我使用(下面会抛出错误):
dataset = tf.data.Dataset.list_files(file_pattern=...)
dataset = dataset.batch(32).map(map_function)
for image in dataset.as_numpy_iterator():
#Error is displayed
print(image)
ValueError:形状必须为 0 级,但对于具有输入形状的“ReadFile”(操作:“ReadFile”)为 1 级:[?]。
现在,根据这个:https://www.tensorflow.org/guide/data_performance#vectorizing_mapping,代码不应该失败并且预处理步骤应该被优化(批处理与一次性处理)。
我的代码中的错误在哪里?
***如果我使用map().batch()
它工作正常
解决方案
发生错误是因为map_function
需要未批处理的元素,但在第二个示例中,您为其提供了批处理元素。
https://www.tensorflow.org/guide/data_performance中的示例通过定义一个increment
可以应用于批处理和非批处理元素的函数而变得棘手,因为将 1 添加到像 [1, 2, 3] 这样的批处理元素将导致在 [2, 3, 4] 中。
def increment(x):
return x+1
要使用向量化,您需要编写 a vectorized_map_function
,它接受一个未批处理元素的向量,将 map 函数应用于向量中的每个元素,然后返回结果向量。
不过,在您的情况下,我认为矢量化不会产生明显的影响,因为读取和解码文件的成本远高于调用函数的开销。当 map 函数非常便宜时,矢量化的影响最大,以至于函数调用所花费的时间与在 map 函数中实际工作所花费的时间相当。
推荐阅读
- c++ - shared_ptr 对象上的 dynamic_cast
- magento2 - Magento 2 的全新安装 - 商店加载正常,管理员登录失败并出现 CURLPROTO_HTTP 错误
- typo3 - 使用flux:field.inline.fal时如何获取fal对象而不是数组
- javascript - jquery .val() 方法在动态生成的元素上获取请求中
- javascript - 助推器。导航栏问题,不要停留在顶部
- java - Intellij 没有看到我的新 JAR
- .htaccess - 来自子域的 CORS 屏障
- django - Django Admin - 登录时出现 404 错误
- sql - 如何按以下要求从 SQL Server 检索数据?
- python-3.x - 分离 Python3 依赖项