TF Dataset from Keras Sequence Class

Problem Description

I figured I'd share something that took me a while to work out: easily wrapping an existing Keras Sequence class with a TF Dataset object. After following the tutorials and migrating from TF 1.X and Keras to TF 2.X, I finally figured out how to do it with minimal code. Hopefully I'm not the only one who struggled with this, and others will find it helpful :)
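For reference, here is a minimal sketch of the kind of Sequence class being wrapped. The class name, constructor arguments, and random placeholder data are all hypothetical; a real implementation would load actual images and labels in __getitem__:

import numpy as np
import tensorflow as tf

class ExampleSequence(tf.keras.utils.Sequence):
    # Hypothetical Sequence that yields one (data, labels) batch per index
    def __init__(self, nSamples, batchSize, dims=(512, 512, 3), n_classes=2):
        self.nSamples = nSamples
        self.batchSize = batchSize
        self.dims = dims
        self.n_classes = n_classes

    def __len__(self):
        # number of batches per epoch
        return self.nSamples // self.batchSize

    def __getitem__(self, idx):
        # a real Sequence would load images and masks from disk here;
        # random arrays stand in for the data in this sketch
        data = np.random.rand(self.batchSize, *self.dims).astype("float32")
        labels = np.random.rand(self.batchSize, self.dims[0], self.dims[1], self.n_classes).astype("float32")
        return data, labels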

A few assumptions (consistent with the default arguments in the code below): the Sequence returns fixed-size (data, labels) batches, the data is channels-last image data, and the labels are one-hot encoded with the same spatial dimensions as the input (e.g., segmentation masks).

import tensorflow as tf

def DatasetFromSequenceClass(sequenceClass, stepsPerEpoch, nEpochs, batchSize, dims=[512,512,3], n_classes=2, data_type=tf.float32, label_type=tf.float32):
    # eager execution wrapper
    def DatasetFromSequenceClassEagerContext(func):
        def DatasetFromSequenceClassEagerContextWrapper(batchIndexTensor):
            # Use a tf.py_function to prevent auto-graph from compiling the method
            tensors = tf.py_function(
                func,
                inp=[batchIndexTensor],
                Tout=[data_type, label_type]
            )

            # set the shape of the tensors - assuming channels last
            tensors[0].set_shape([batchSize, dims[0], dims[1], dims[2]])   # [samples, height, width, nChannels]
            tensors[1].set_shape([batchSize, dims[0], dims[1], n_classes]) # [samples, height, width, nClasses for one hot]
            return tensors
        return DatasetFromSequenceClassEagerContextWrapper

    # TF dataset wrapper that indexes our sequence class
    @DatasetFromSequenceClassEagerContext
    def LoadBatchFromSequenceClass(batchIndexTensor):
        # get our index as numpy value - we can use .numpy() because we have wrapped our function
        batchIndex = batchIndexTensor.numpy()

        # zero-based index for which batch of data to load; i.e., it wraps back to 0 at stepsPerEpoch and starts the count over
        zeroBatch = batchIndex % stepsPerEpoch

        # load data
        data, labels = sequenceClass[zeroBatch]

        # convert to tensors and return
        return tf.convert_to_tensor(data), tf.convert_to_tensor(labels)

    # create our data set for how many total steps of training we have
    dataset = tf.data.Dataset.range(stepsPerEpoch*nEpochs)

    # return dataset using map to load our batches of data, use TF to specify number of parallel calls
    return dataset.map(LoadBatchFromSequenceClass, num_parallel_calls=tf.data.experimental.AUTOTUNE)

Using that function, you can update your training to look like this:

# load our data as tensorflow datasets
training = DatasetFromSequenceClass(trainingSequence, training_steps, nEpochs, batchSize, dims=shp, n_classes=nClasses)
validation = DatasetFromSequenceClass(validationSequence, validation_steps, nEpochs, batchSize, dims=shp, n_classes=nClasses)

# train
model_object.fit(training,
    steps_per_epoch=training_steps,
    validation_data=validation,
    validation_steps=validation_steps,
    epochs=nEpochs,
    callbacks=callbacks,
    verbose=1)

From here, the Dataset API has many other options (such as prefetching), but this should be a good starting point.
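As a small illustration (reusing the training and validation variables from the snippet above), prefetching overlaps data loading with model execution and can be added in one line per dataset:

# prepare upcoming batches in the background while the model is busy training
training = training.prefetch(tf.data.experimental.AUTOTUNE)
validation = validation.prefetch(tf.data.experimental.AUTOTUNE)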

Tags: tensorflow, keras, tf.keras
