How to catch and raise tfrecord errors in tf.dataset.map for tensorflow < 2.0

Problem description

Using tf.data (tensorflow < 2.0), I ran into a rare corrupted file that cannot be reshaped to the expected dimensions. Each tfrecord entry holds the filename of an image to read along with its height and width. I want to catch the error and print that filename so the file can be removed.

Where should the try/except go in the parser so that the filename is raised?

import tensorflow as tf

def _parse_fn(example):
    # Define features
    features = {
        'image/filename': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
    }

    # Load one example and parse
    example = tf.io.parse_single_example(example, features)

    # Load image from file
    filename = tf.cast(example["image/filename"], tf.string)
    loaded_image = tf.read_file(filename)
    loaded_image = tf.image.decode_image(loaded_image, 3)

    # Reshape to known shape
    image_rows = tf.cast(example['image/height'], tf.int32)
    image_cols = tf.cast(example['image/width'], tf.int32)

    # Wrap in a try/except and report the failing file
    try:
        loaded_image = tf.reshape(loaded_image,
                                  tf.stack([image_rows, image_cols, 3]),
                                  name="cast_loaded_image")
    except tf.errors.InvalidArgumentError as e:
        print("Image filename: {} yielded {}".format(filename, e))

    return loaded_image

# Map the parser over every record; num_parallel_calls sets the number of parallel loaders.
dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

The error is not caught and the filename is never printed:

File "/apps/tensorflow/1.14.0.cuda10.gpu/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1458, in __call__
run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError:
2 root error(s) found.
  (0) Invalid argument: Input to reshape is a tensor with 480000 values, but the requested shape has 259518
     [[{{node cast_loaded_image}}]]
     [[IteratorGetNext]]
     [[replica_3/retinanet/bn2c_branch2a/FusedBatchNorm/ReadVariableOp/_987]]
  (1) Invalid argument: Input to reshape is a tensor with 480000 values, but the requested shape has 259518
     [[{{node cast_loaded_image}}]]
     [[IteratorGetNext]]
0 successful operations.
3 derived errors ignored.

Tags: tensorflow, tfrecord

Solution


A colleague suggested this. I will accept it if there are no other answers; I think it has some value.

One strategy is to use

dataset = dataset.apply(tf.data.experimental.ignore_errors())

and apply a parser that returns the filename of each tfrecord. After running this and comparing the result to the length of the original records, you can see how many images were corrupted. I think there must be other solutions, but if you have access to the original set of filenames that were wrapped into the tfrecords, you can diff the two lists and get the bad filenames, as in the sketch below.
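A minimal sketch of that strategy for TF 1.x graph mode follows. The record path ("train.tfrecord") and the helper names (_filename_only_fn, _parse_returning_filename, _run_to_list) are illustrative assumptions, not part of the original code. One pass collects every filename stored in the tfrecords, a second pass does the full decode + reshape under ignore_errors() so corrupted records are silently dropped, and the set difference yields the bad filenames.

import tensorflow as tf

RECORD_PATH = "train.tfrecord"  # assumed path to the tfrecord file

def _filename_only_fn(example):
    # Parse only the filename feature; this never touches the image bytes,
    # so it succeeds for every record, corrupted or not.
    features = {'image/filename': tf.io.FixedLenFeature([], tf.string)}
    return tf.io.parse_single_example(example, features)['image/filename']

def _parse_returning_filename(example):
    # Do the same decode + reshape work as _parse_fn above, so a corrupted
    # record still fails, but yield only the filename.
    features = {
        'image/filename': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, features)
    image = tf.image.decode_image(tf.read_file(parsed['image/filename']), 3)
    image = tf.reshape(image,
                       tf.stack([tf.cast(parsed['image/height'], tf.int32),
                                 tf.cast(parsed['image/width'], tf.int32), 3]))
    # The control dependency keeps the reshape from being pruned out of the
    # graph even though only the filename is returned.
    with tf.control_dependencies([image]):
        return tf.identity(parsed['image/filename'])

def _run_to_list(dataset):
    # Drain a dataset inside a Session (TF 1.x) and collect its elements.
    next_element = dataset.make_one_shot_iterator().get_next()
    values = []
    with tf.Session() as sess:
        while True:
            try:
                values.append(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                break
    return values

all_filenames = set(_run_to_list(
    tf.data.TFRecordDataset(RECORD_PATH).map(_filename_only_fn)))

surviving_filenames = set(_run_to_list(
    tf.data.TFRecordDataset(RECORD_PATH)
    .map(_parse_returning_filename)
    .apply(tf.data.experimental.ignore_errors())))

print("Corrupted files:", all_filenames - surviving_filenames)

Because ignore_errors() drops failing records without reporting them, the diff against the filename-only pass is what recovers which files were actually bad.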

