首页 > 解决方案 > TFRecord - 将 png 转换为字节

问题描述

创建 tfrecord 的代码:

def convert(self):
    with tf.python_io.TFRecordWriter(self.tfrecord_out) as writer:
        example = self._convert_image()
        writer.write(example.SerializeToString())

def _convert_image(self):
    for (path, label) in zip(self.image_paths, self.labels):
        label = int(label)
        # Read image data in terms of bytes
        with open(path, 'rb') as fid:
            png_bytes = fid.read()

        example = tf.train.Example(features=tf.train.Features(feature={
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[png_bytes]))
            }))
    return example

我的问题是当我从文件中读取图像无法正确解码时:

def parse(self, serialized):
    features = \
        {
            'image': tf.FixedLenFeature([], tf.string)
        }

    parsed_example = tf.parse_single_example(serialized=serialized,
                                                 features=features)

    image_raw = parsed_example['image']
    image = tf.image.decode_png(contents=image_raw, channels=3, dtype=tf.uint8)
    image = tf.cast(image, tf.float32)
    return image`

有谁知道这是为什么?

垃圾图片

标签: tensorflow

解决方案


找到了解决方案,希望我的愚蠢错误对其他人有所帮助。

在将张量重塑为张量板的 4 个维度时,[batch_size, height, width, channels]我切换了宽度和高度。

正确的重塑代码是:

x_reshaped = session.run(tf.reshape(tensor=decoded_png_uint8, shape=[batch_size, height, width, channels], name="x_reshaped"))

但我有shape=[batch_size, width, height, channels]。呃,好吧。每天都是上学日。

正确的输出


推荐阅读