tensorflow - How to catch and report tfrecord errors inside tf.dataset.map with tensorflow < 2.0
Question
Using tf.data (tensorflow < 2.0), I ran into a rare corrupted file that cannot be reshaped to the recorded dimensions. Each tfrecord row holds the filename of an image to read plus its height and width. I want to catch the error and print the filename so the file can be deleted.
Where should the try/except go in the parser so that it reports the filename?
def _parse_fn(example):
    # Define features
    features = {
        'image/filename': tf.io.FixedLenFeature([], tf.string),
        "image/height": tf.io.FixedLenFeature([], tf.int64),
        "image/width": tf.io.FixedLenFeature([], tf.int64),
    }
    # Load one example and parse
    example = tf.io.parse_single_example(example, features)
    # Load image from file
    filename = tf.cast(example["image/filename"], tf.string)
    loaded_image = tf.read_file(filename)
    loaded_image = tf.image.decode_image(loaded_image, 3)
    # Reshape to known shape
    image_rows = tf.cast(example['image/height'], tf.int32)
    image_cols = tf.cast(example['image/width'], tf.int32)
    # Wrap in a try/except and report the failing file
    try:
        loaded_image = tf.reshape(loaded_image,
                                  tf.stack([image_rows, image_cols, 3]),
                                  name="cast_loaded_image")
    except tf.errors.InvalidArgumentError as e:
        print("Image filename: {} yielded {}".format(filename, e))

# Map the parser over every record; you can set the number of parallel loaders here.
dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
The error is not caught and the filename is never printed:
File "/apps/tensorflow/1.14.0.cuda10.gpu/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1458, in __call__
run_metadata_ptr)
tensorflow.python.framework.errors_impl.InvalidArgumentError:
2 root error(s) found.
(0) Invalid argument: Input to reshape is a tensor with 480000 values, but the requested shape has 259518
[[{{node cast_loaded_image}}]]
[[IteratorGetNext]]
[[replica_3/retinanet/bn2c_branch2a/FusedBatchNorm/ReadVariableOp/_987]]
(1) Invalid argument: Input to reshape is a tensor with 480000 values, but the requested shape has 259518
[[{{node cast_loaded_image}}]]
[[IteratorGetNext]]
0 successful operations.
3 derived errors ignored.
Solution
A colleague suggested this; I will accept it if no other answers come in, since I think it has some value.
One strategy is to use
dataset = dataset.apply(tf.data.experimental.ignore_errors())
and apply a parser that returns the filename of each tfrecord. After running this and comparing the result against the full list of original records, you can find which images were corrupted. There must be other solutions, but if you have access to the original set of images that were wrapped into tfrecords, you can diff the two lists and recover the missing filenames.
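A minimal sketch of that strategy, assuming TF 1.x APIs; `RECORD_PATH`, `_filename_only_fn`, `diff_filenames`, and `build_pipelines` are hypothetical names introduced here for illustration:

```python
import tensorflow as tf

RECORD_PATH = "records.tfrecord"  # hypothetical path to your tfrecord file

def _filename_only_fn(example):
    # Parse just the filename feature; no image decode or reshape happens,
    # so this map cannot fail on a corrupted image.
    features = {"image/filename": tf.io.FixedLenFeature([], tf.string)}
    return tf.io.parse_single_example(example, features)["image/filename"]

def diff_filenames(all_filenames, surviving_filenames):
    # Records whose filenames never came out of the ignore_errors pipeline
    # are the corrupted ones.
    return sorted(set(all_filenames) - set(surviving_filenames))

def build_pipelines(parse_fn):
    # One pass that lists every filename, and one full pass that silently
    # drops any record whose decode/reshape raises.
    all_names = tf.data.TFRecordDataset(RECORD_PATH).map(_filename_only_fn)
    survivors = (tf.data.TFRecordDataset(RECORD_PATH)
                 .map(parse_fn)
                 .apply(tf.data.experimental.ignore_errors()))
    return all_names, survivors
```

Drain both datasets into Python lists of filenames (in TF 1.x, via a one-shot iterator and `sess.run` until `tf.errors.OutOfRangeError`), then `diff_filenames` yields the records to delete; note `parse_fn` must return the filename alongside the image for this comparison to work.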