首页 > 解决方案 > 将列表项映射到 tensorflow 数据集字典

问题描述

我正在尝试将图像信息映射到由图像和标签字典组成的数据集。

parse_function()应该只从 2 个文件名路径和标签列表中解码。

 def parse_function(filename, label):
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize(image_decoded, [4, 4])

    return image_resized, label

def dataset_maker(list_sample_paths, list_labels):

    filenames = tf.constant(list_sample_paths)
    labels = tf.constant(list_labels)

    dataset = tf.data.Dataset.from_tensor_slices({"image": filenames, "label": labels})
    dataset = dataset.map(parse_function)

training_dataset = dataset_maker(list_training_sample_paths, list_training_sample_labels)

但我收到此错误消息

TypeError: tf__parse_function() missing 1 required positional argument: 'label'

在这种情况下我需要使用字典理解吗?非常感谢解决此问题的任何帮助。谢谢!

在 Srihari Humbarwadi 回复后添加此信息以使用元组解决它: 我想获得一个字典结构,因为我用 Mnist 为我的模型下雨了。

一个随机的 Mnist 示例具有以下结构:

{'image': <tf.Tensor: id=140275, shape=(28, 28, 1), dtype=uint8, numpy=array([[[  0],[  0],[  0]],dtype=uint8)>, 'label': <tf.Tensor: id=140276, shape=(), dtype=int64, numpy=6>}

标签: pythontensorflowdictionarymappingdictionary-comprehension

解决方案


您不需要以字典的形式传递文件名和标签列表。你可以通过传递一个元组来让它工作,即。(filenames, labels). 这是我使用的完整代码:

from glob import glob
import numpy as np
import tensorflow as tf

print('TensorFlow:', tf.__version__)

list_training_sample_paths = sorted(glob('images/*'))
# random integer labels
list_training_sample_labels = np.random.randint(low=0, high=5, size=[len(list_training_sample_paths)])

def parse_function(filename, label):
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize(image_decoded, [4, 4])

    return image_resized, label

def dataset_maker(list_sample_paths, list_labels):

    filenames = tf.constant(list_sample_paths)
    labels = tf.constant(list_labels)

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(parse_function)
    return dataset

training_dataset = dataset_maker(list_training_sample_paths, list_training_sample_labels)
tf.data.experimental.get_structure(training_dataset)

输出

TensorFlow: 2.2.0-rc2
(TensorSpec(shape=(4, 4, None), dtype=tf.float32, name=None), TensorSpec(shape=(),dtype=tf.int64, name=None))

推荐阅读