How to unpack data from a TensorFlow dataset?

Problem description

Here is my code for loading data from a TFRecord file:

def read_tfrecord(tfrecord, epochs, batch_size):

    dataset = tf.data.TFRecordDataset(tfrecord)

    def parse(record):
        features = {
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.int64)
        }
        example = tf.io.parse_single_example(record, features)
        image = decode_image(example["image"])
        label = tf.cast(example["target"], tf.int32)
        return image, label

    dataset = dataset.map(parse)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(epochs)

    return dataset


x_train, y_train = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)

I get the following error:

ValueError: too many values to unpack (expected 2)
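This error is ordinary Python tuple unpacking, not TensorFlow-specific: assigning an iterable to two names requires it to yield exactly two items, but a dataset yields one `(image, label)` pair per batch, of which there are many. A minimal pure-Python analogy (the list of pairs is a stand-in for the dataset, no TensorFlow needed):

```python
# Stand-in for a dataset: an iterable of (image, label) pairs.
batches = [("img0", 0), ("img1", 1), ("img2", 2)]

try:
    x, y = batches  # 3 items into 2 names -> fails, like the dataset
except ValueError as e:
    print(e)  # too many values to unpack (expected 2)

# Taking one element at a time works: each element IS a 2-tuple.
x, y = next(iter(batches))
print(x, y)  # img0 0
```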

My question is:

How do I unpack the data from the dataset?

Tags: python, tensorflow2.0, tensorflow-datasets

Solution


`read_tfrecord` returns a single `tf.data.Dataset` object, not an `(x, y)` pair, so don't unpack its return value. You can try this instead:

dataset = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)

iterator = iter(dataset)
x, y = next(iterator)  # first (image, label) batch
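Besides `next(iter(...))`, the usual way to consume the dataset is a plain `for` loop (or passing the dataset directly to Keras `model.fit`). A minimal sketch with a toy in-memory dataset standing in for the parsed TFRecords; the shapes and array names are illustrative assumptions, not from the original post:

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for the parsed TFRecord data: 8 fake "images" and labels.
images = np.random.rand(8, 4, 4, 3).astype(np.float32)
labels = np.arange(8, dtype=np.int32)

dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Each iteration yields one (image_batch, label_batch) pair.
for x, y in dataset:
    print(x.shape, y.shape)  # (4, 4, 4, 3) (4,)
```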
