首页 > 解决方案 > “正常” python 对象会自动成为张量吗?例如 "a" 或 1

问题描述

我很难理解为什么这个功能有效:

def tf_load_data(training, batch_size, img_path_list, label_list):

    # Arguments:
    # path_to_image: a Tensor of type string
    # returns 3-D float Tensor of shape [new_height, new_width, channels]
    def get_img(path_to_img):
        # load the raw data, encoded as string.
        raw_img = tf.io.read_file(path_to_img)
        # Creates a 3D uint8 tensor.
        img = ts.io.decode_png(raw_img, channels=3)  # pictures are not saved as Grayscale
        # Changes the values in the tensor to be floats in [0,1). -- Normalization
        img = ts.image.convert_image_dtype(img, tf.float32)
        # Resize all pictures to the same format.
        return ts.image.resize(img, [constant.IMG_WIDTH, constant.IMG_HEIGHT])

    # Arguments:
    # label_string: as byte:32 which represents a string
    # path_to_image: as byte:32 which represents a string
    # returns a pair of two Tensors
    def get_pair(path_to_img, label_string):
        return get_img(path_to_img), lable_string

    # Arguments: -- function is use together with tf.data.Dataset.map or tf.data.Dataset.apply
    # img: is a Tensor of type String
    # label: is a Tensor of type String
    # return: the type is the same as input
    def pre_process(img, label):
        # Do all the pre-processing:
        return img, label

    dataset = tf.data.Dataset.from_tensor_slices((img_path_list, label_list))
    dataset_tensor = dataset.map(map_func=get_pair, num_parallel_calls=None)

img_path_list 和 label_list 是字符串类型的列表。

我不明白的是,这显然 dataset = tf.data.Dataset.from_tensor_slices((img_path_list, label_list)) 是一个包含元组 () 的张量。因此,当我运行时,.map(map_func=get_pair, num_parallel_calls=None)两个字符串作为元组在get_pair(path_to_img, label_string):函数中传递。然后将这些字符串之一传递给get_img(path_to_img)函数,最后传递给tf.io.read_file(path_to_img). 问题是io.read_file()需要输入:“字符串类型的张量”(参见文档:https ://www.tensorflow.org/api_docs/python/tf/io/read_file )。但是 string != 到字符串类型的张量:

isinstance(tf.constant["hello"], tf.Tensor) == True

isinstance("hello", tf.Tensor) == False也:isinstance(["hello"], tf.Tensor) == False

谢谢你的帮助!

标签: pythonpython-3.xtensorflowtensorflow2.0tensorflow-datasets

解决方案


推荐阅读