首页 > 解决方案 > Tensorflow 数据集交错读取图像文件路径

问题描述

我正在尝试使用 tensorflow 数据集 API 来加载可以在网络存储上的 tif 图像。我传递给 tensorflow.Dataset.interleave 函数的 map_func 没有传递我期望的字符串文件名,而是一个具有 dtype 字符串的张量。

我尝试使用 sess.run 和 tensor.eval() 评估此张量(也将当前会话作为会话参数传递),但 tensorflow 引发 ValueError:“ValueError:Fetch 参数不能解释为张量。(张量张量("arg0:0", shape=(), dtype=string) 不是该图的元素。) or("arg0:0", shape=(), dtype=string) 不是该图的元素。 )”。

我的张量流数据管道的一个例子

class DataLoader:

...

    def setup(self):

        ...

        tf.data.Dataset.from_tensor_slices(
            (
                self.training_filenames, # a python list of strings
                self.training_label_filenames # a python list of strings
            )
        )
        .apply(tf.data.experimental.filter_for_shard(
            self.shard_count,
            self.shard_index))
        .repeat()
        .shuffle(buffer_size=self.training_data_shuffle_buffer_size)
        .interleave(
            lambda data_filepath, label_filepath: (
                self.preprocess_training_data(data_filepath, label_filepath)
            ),
            cycle_length=tf.data.experimental.AUTOTUNE,
            num_parallel_calls=tf.data.experimental.AUTOTUNE
        )
        .batch(self.training_data_batch_size)
        .prefetch(self.training_data_batch_size)

        ...

    def preprocess_training_data(self, data_filepath, label_filepath):
        data = tifffile.imread(self.session.run(data_filepath).decode())
        data_resize = (self.training_data_shape[0], self.training_data_shape[1])
        data_transpose = (1, 0, 2)
        data_scale = 255.0
        data_dtype = self.training_data_type.as_numpy_dtype()

        data = numpy.transpose(
            cv2.resize(data, data_resize), data_transpose
        ).astype(data_dtype) / data_scale

        label = tifffile.imread(self.session.run(label_filepath))
        label = numpy.transpose(
                    numpy.expand_dims(
                        cv2.resize(
                            label,
                            data_resize),
                        2
                    ),
                    data_transpose
                )

        weights = [
            self.vec_class_weights(current_label)
                .astype(self.label_data_type.as_numpy_dtype())
            for current_label
            in label
        ]

        return data, label, weights

我希望我的 preprocess_training_data 函数被传递字符串,或者我将能够评估我的函数从数据集交错转换传递的张量,该转换将评估为字符串。

标签: tensorflow-datasets

解决方案


推荐阅读