首页 > 解决方案 > 如何使用谓词按标签过滤我的自定义张量流数据集

问题描述

我使用以下代码从本地图像目录创建了一个 tensorflow 数据集:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        train_dir=my_img_dir,
        validation_split=0.2,
        subset="training",
        seed=123,
        image_size=(256,256),
        batch_size=32)

我有 1301 个标签。这是一个分类图像问题。由于类别众多,我选择尝试连体网络我正在尝试创建成对的图像来提供我的连体网络,所以我需要相同类别的图像对和不同类别的图像对。为此,我尝试使用此代码进行过滤(例如仅过滤标签== 314)

    def predicate(images, labels, allowed_labels=tf.constant([314],dtype=tf.int32)):
    #label = x['label']
    isallowed = tf.equal(allowed_labels, tf.cast(labels, tf.int32))
    reduced = tf.reduce_sum(tf.cast(isallowed, tf.int32))
    return tf.greater(reduced, tf.constant(0))


#dataset_314 = train_ds.filter(predicate)

dataset_314 = train_ds.filter(lambda img, label: label == 314)

上面的每个过滤器都会给我一个错误。使用 lambda 的过滤器给出了这个错误:

ValueError: `predicate` return type must be convertible to a scalar boolean tensor. Was TensorSpec(shape=(None,), dtype=tf.bool, name=None).

使用自定义函数谓词,我没有错误,但它根本没有过滤。我想了解它是如何工作的。如何过滤,知道数据集是分批加载的。如何返回具有相同类别/标签的两个图像的元组?谢谢您的帮助。

标签: pythontensorflow2.0tensorflow-datasets

解决方案


推荐阅读