首页 > 解决方案 > 在急切模式下对字符串张量调用 map 时,需要一个类似字节的对象,而不是“张量”

问题描述

我正在尝试使用 TF.dataset.map 移植此旧代码,因为我收到了弃用警告。

从 TFRecord 文件中读取一组自定义原型的旧代码:

record_iterator = tf.python_io.tf_record_iterator(path=filename)
for record in record_iterator:
    example = MyProto()
    example.ParseFromString(record)

我正在尝试使用渴望模式和地图,但出现此错误。

def parse_proto(string):
      proto_object = MyProto()
      proto_object.ParseFromString(string)
dataset = tf.data.TFRecordDataset(dataset_paths)
parsed_protos = raw_tf_dataset.map(parse_proto)

此代码有效:

for raw_record in raw_tf_dataset:                                                                                                                                         
    proto_object = MyProto()                                                                                                                                              
    proto_object.ParseFromString(raw_record.numpy())                                                                                                                                 

但是地图给了我一个错误:

TypeError: a bytes-like object is required, not 'Tensor'

什么是使用参数映射的函数结果并将它们视为字符串的正确方法?

标签: tensorflowtensorflow-datasets

解决方案


您需要从张量中提取字符串并在map函数中使用。以下是要在代码中实现的步骤。

  1. 你必须用tf.py_function(get_path, [x], [tf.float32]). 您可以在此处找到有关 tf.py_function的更多信息。在tf.py_function中,第一个参数是map函数的名称,第二个参数是要传递给map函数的元素,最后一个参数是返回类型。
  2. 您可以使用bytes.decode(file_path.numpy())in map 函数获取您的字符串部分。

所以修改你的程序如下,

parsed_protos = raw_tf_dataset.map(parse_proto)

parsed_protos = raw_tf_dataset.map(lambda x: tf.py_function(parse_proto, [x], [function return type]))

也修改parse_proto如下,

def parse_proto(string):
      proto_object = MyProto()
      proto_object.ParseFromString(string)

def parse_proto(string):
      proto_object = MyProto()
      proto_object.ParseFromString(bytes.decode(string.numpy()))

在下面的简单程序中,我们tf.data.Dataset.list_files用于读取图像的路径。接下来在map函数中,我们使用读取图像load_img,然后执行该tf.image.central_crop函数以裁剪图像的中心部分。

代码 -

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
  for l in f:
    image = np.array(array_to_img(l))
    plt.imshow(image)

输出 -

在此处输入图像描述

希望这能回答你的问题。快乐学习。


推荐阅读