python - 如何正确使用 tf.function 与 TensorFlow 数据集
问题描述
我正在尝试使用带有 @tf.function 的 TF 数据集对图像目录执行一些预处理。在tf函数内部,图像文件被读取为 RAW 字符串张量,我试图从该张量中取出一个切片。切片(前 13 个字符)表示有关 .ppm 图像(标题)的信息。我收到一个错误:ValueError: Shape must be rank 1 but is rank 0 for 'Slice' (op: 'Slice') with input shapes: [], [1], [1]
。最初我试图直接对张量的 .numpy() 属性进行切片(tffilepath
函数的输入参数),但我认为在tf函数中执行此操作在语义上是错误的。它也不起作用,因为输入张量没有 numpy() 属性(我不明白为什么??)。在tf之外filepath
函数,例如在 jupyter 笔记本单元格中,我可以遍历数据集并获取具有 numpy 属性的单个项目,并对其进行切片和所有后续处理。我确实意识到我对 TF 如何工作的理解可能存在差距(我使用的是 TF 2.0),所以我希望有人能澄清我在阅读中遗漏的内容。tf函数的目的是将 ppm 图像转换为 png,所以这个函数有一个副作用,但我没有走得那么远来找出是否可以这样做。
这是代码:
@tf.function
def ppm_to_png(filepath):
ppm_bytes = tf.io.read_file(filepath) #.numpy()
bytes_header = tf.slice(ppm_bytes, [0], [13])
# bytes_header = ppm_bytes[:13].eval() # this did not work either with similar error msg
.
.
.
import glob
files = glob.glob(os.path.join(data_dir, '00000/*.ppm'))
dataset = tf.data.Dataset.from_tensor_slices(files)
png_filepaths = dataset.map(ppm_to_png, num_parallel_calls=tf.data.experimental.AUTOTUNE)
解决方案
要在 TF 中操作字符串值,请查看tf.strings 命名空间。
在这种情况下,您可以使用tf.strings.substr
:
@tf.function
def ppm_to_png(filepath):
ppm_bytes = tf.io.read_file(filepath)
bytes_header = tf.strings.substr(ppm_bytes, 0, 13)
tf.print(bytes_header)
tf.slice
只对张量对象起作用,对它们的元素不起作用。这里,ppm_bytes
是一个标量张量,包含一个类型为 的元素tf.string
,其值是文件的整个字符串内容。因此,当您调用 时tf.slice
,它仅查看标量位,并且不够聪明,无法意识到您实际上想要取该元素的一部分。
推荐阅读
- python - 无法单击字段集中的单选按钮
- vue.js - 使用 vuex-module-decorators 时组件的属性未更新
- excel - 如果基于复选框不存在,则在列表框中添加项目
- php - 如何在 PHP 中处理多个查询,包括多个 TEMPORARY 表
- linux - 在 docker 容器中创建应用程序(对任务顺序感到困惑)
- php - 如何在 PHP 中从 mongodb/driver/manager 获取集合类对象
- rabbitmq - 当响应来自使用 Spring Integration DSL 的 rabbitMQ 回复队列时,如何实现 HTTP 请求/回复?
- html - HTML不会从scss SASS中获取margin-bottom
- palantir-foundry - 如何分配热键以触发 Slate 中的事件
- sql-server - SQL Server XML 架构验证;无效内容错误;期望将次要事件设置为 0 的元素?