首页 > 解决方案 > 解码的 TFRecord 大小错误

问题描述

我正在使用 TensorFlows DeepLab,想看看我的数据是否正确转换为 tfrecords。我使用预定义download_and_convert_ade20k.sh和以下代码将其解码回来:

import tensorflow as tf

feature_desc = {
    'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
    'image/filename': tf.io.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.io.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/height': tf.io.FixedLenFeature((), tf.int64, default_value=0),
    'image/width': tf.io.FixedLenFeature((), tf.int64, default_value=0),
    'image/segmentation/class/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
    'image/segmentation/class/format': tf.io.FixedLenFeature((), tf.string, default_value='png'),
}

def _parse_image_function(example_proto):
    return tf.io.parse_single_example(example_proto, feature_desc)

raw_image_dataset = tf.data.TFRecordDataset('ADE20K/tfrecord/train-00000-of-00004.tfrecord')

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)

for image_features in parsed_image_dataset:
    img = tf.io.decode_raw(image_features['image/encoded'], tf.uint8).numpy()
    seg = tf.io.decode_raw(image_features['image/segmentation/class/encoded'], tf.uint8).numpy()

我有两个问题。首先,我在编码之前不知道图像的数据类型。我猜tf.uint8是因为它是唯一一个也可以输出 0 到 255 之间的值的工作。另一方面,获得的图像是一维的(如预期的那样),但不足以将其重塑为原始形状。解码后的图像长度为 48836,而原始高度和宽度为 (512, 612),总共等于 313344 个像素。

标签: pythontensorflowtfrecord

解决方案


推荐阅读