首页 > 解决方案 > 如何从 TFRecordDataset 获取张量的形状

问题描述

我在培训 TFRecord 中写入了以下功能:

feature = {'label': _int64_feature(gt),
           'image': _bytes_feature(tf.compat.as_bytes(im.tostring())),
           'height': _int64_feature(h),
           'width': _int64_feature(w)}

我正在阅读它:

train_dataset = tf.data.TFRecordDataset(train_file)
train_dataset = train_dataset.map(parse_func)
train_dataset = train_dataset.shuffle(buffer_size=1)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(batch_size)

而我的 parse_func 看起来像这样:

def parse_func(ex):
    feature = {'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               'height': tf.FixedLenFeature([], tf.int64),
               'width': tf.FixedLenFeature([], tf.int64)}
    features = tf.parse_single_example(ex, features=feature)
    image = tf.decode_raw(features['image'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    im_shape = tf.stack([width, height])
    image = tf.reshape(image, im_shape)
    label = tf.cast(features['label'], tf.int32)
    return image, label

现在,我想获得图像标签的形状,例如:

image.get_shape().as_list()

打印
[None, None, None]
而不是
[None, 224, 224] (图像的大小(批量、宽度、高度))

有没有什么函数可以给我这些张量的大小?

标签: tensorflowtensortensorflow-datasetstfrecord

解决方案


由于您的地图函数“parse_func”作为操作是图形的一部分,并且它不知道输入的固定大小和先验标签,因此使用 get_shape() 不会返回预期的固定形状。

如果您的图像,标签形状是固定的,作为一个黑客,您可以尝试重塑您的图像,具有已知大小的标签(这实际上不会做任何事情,但会明确设置输出张量的大小) .

前任。图像 = tf.reshape(图像,[224,224])

有了这个,您应该能够按预期获得 get_shape() 结果。


推荐阅读