首页 > 解决方案 > 如何正确地将函数映射到数据集的每条记录

问题描述

目标

我正在尝试为图像分割准备我的数据集。.tfrecord我使用以下代码将所有图像及其相关注释转换为文件:

writer = tf.python_io.TFRecordWriter(tfrecords_filename)

for img_path, annotation_path in filename_pairs:
    img = np.array(Image.open(img_path))
    annotation = np.array(Image.open(annotation_path))
    height = img.shape[0]
    width = img.shape[1]

    img_raw = img.tostring()
    annotation_raw = annotation.tostring()

    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'image_raw': _bytes_feature(img_raw),
        'mask_raw': _bytes_feature(annotation_raw)}))

    writer.write(example.SerializeToString())

现在,我正在尝试将这些记录加载到 TF 数据集中:

dataset = tf.data.TFRecordDataset(training_filenames).shuffle(1000).repeat(4).batch(32)

实验

现在,如果我尝试在 this 中显示第一个图像/注释对dataset,它会按预期工作:

batch = next(iter(dataset))
tensor = batch[0]

image, annotation = _parse_function(tensor)
annotation = np.squeeze(annotation.numpy()[:, :], axis=2)
plt.figure()
plt.imshow(image.numpy())
plt.imshow(annotation, alpha=0.5)
plt.show()

我在其中预处理记录_parse_function以提取特征(我有意在急切执行模式下使用 TensorFlow):

def _parse_function(example_proto):
    features = {'height': tf.FixedLenFeature(1, tf.int64),
                'width': tf.FixedLenFeature(1, tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string),
                'mask_raw': tf.FixedLenFeature([], tf.string)}
    parsed_features = tf.parse_single_example(example_proto, features)

    annotation = tf.decode_raw(parsed_features['mask_raw'], tf.uint8)

    height = tf.cast(parsed_features['height'], tf.int32)
    width = tf.cast(parsed_features['width'], tf.int32)
    height = height.numpy()[0]
    width = width.numpy()[0]

    image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)

    image = tf.reshape(image, tf.stack([height, width, 3]))
    annotation = tf.reshape(annotation, tf.stack([height, width, 1]))

    return image, annotation

实际问题

当然,我宁愿将整体dataset变成可直接用于训练分割模型的东西。

但是,如果我尝试对整个内容进行预处理dataset以将其转换为一组功能,使用dataset.map(_parse_function)example_proto所输入的_parse_function内容似乎与我在执行时得到的内容不同next(iter(dataset))[0]。更准确地说,它是 0 阶张量(所以只是一个量级),因此无法正确提取特征。

我对 TF 还是比较陌生,不太明白为什么会这样,也不明白这个张量代表什么。

是否map批量调用回调函数而不是底层示例?我尝试删除,batch(32)但文档说默认行为是生成大小为 1 的批次,这不一定能解决问题。

任何帮助,将不胜感激!

标签: pythontensorflow

解决方案


推荐阅读