tensorflow - TF Dataset from a Keras Sequence class
Problem description
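I thought I would share something that took me a while to figure out: an easy way to wrap an existing Keras Sequence class in a TF Dataset object. After following tutorials and migrating from TF 1.X and Keras to TF 2.X, I finally worked out how to do it with minimal code. Hopefully I'm not the only one who struggled with this, and others will find it helpful :)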
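A few assumptions:
- The Sequence class loads both data and labels
- The labels have the same shape as the source data, except for the channel axis (i.e. this is what I use to train a U-Net)
- The data format is channels last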
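For reference, here is a minimal sketch of a sequence class that meets these assumptions. `ToySegmentationSequence` and its random data are hypothetical placeholders, not my actual class; a real one would typically subclass `tf.keras.utils.Sequence` and read images from disk. The wrapper only indexes the object, so `__getitem__` returning a `(data, labels)` pair is what matters.

```python
import numpy as np


class ToySegmentationSequence:
    """Hypothetical stand-in for a Keras Sequence class (assumption: random data)."""

    def __init__(self, n_samples=8, batch_size=2, dims=(64, 64, 3), n_classes=2, seed=0):
        rng = np.random.default_rng(seed)
        self.batch_size = batch_size
        self.n_classes = n_classes
        # channels-last images: [samples, height, width, channels]
        self.images = rng.random((n_samples, *dims), dtype=np.float32)
        # one integer class per pixel; converted to one-hot when a batch is requested
        self.masks = rng.integers(0, n_classes, size=(n_samples, dims[0], dims[1]))

    def __len__(self):
        # number of full batches per epoch
        return len(self.images) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        stop = start + self.batch_size
        data = self.images[start:stop]
        # one-hot labels: same height/width as the data, last axis is n_classes
        labels = np.eye(self.n_classes, dtype=np.float32)[self.masks[start:stop]]
        return data, labels


sequence = ToySegmentationSequence()
data, labels = sequence[0]
print(data.shape, labels.shape)
```

The labels differ from the data only in the last axis, which is what the `set_shape` calls in the wrapper rely on.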
import tensorflow as tf

def DatasetFromSequenceClass(sequenceClass, stepsPerEpoch, nEpochs, batchSize, dims=[512,512,3], n_classes=2, data_type=tf.float32, label_type=tf.float32):
    # eager execution wrapper
    def DatasetFromSequenceClassEagerContext(func):
        def DatasetFromSequenceClassEagerContextWrapper(batchIndexTensor):
            # Use a tf.py_function to prevent auto-graph from compiling the method
            tensors = tf.py_function(
                func,
                inp=[batchIndexTensor],
                Tout=[data_type, label_type]
            )
            # set the shape of the tensors - assuming channels last
            tensors[0].set_shape([batchSize, dims[0], dims[1], dims[2]])  # [samples, height, width, nChannels]
            tensors[1].set_shape([batchSize, dims[0], dims[1], n_classes])  # [samples, height, width, nClasses for one hot]
            return tensors
        return DatasetFromSequenceClassEagerContextWrapper

    # TF dataset wrapper that indexes our sequence class
    @DatasetFromSequenceClassEagerContext
    def LoadBatchFromSequenceClass(batchIndexTensor):
        # get our index as numpy value - we can use .numpy() because we have wrapped our function
        batchIndex = batchIndexTensor.numpy()
        # zero-based index for what batch of data to load; i.e. goes to 0 at stepsPerEpoch and starts the count over
        zeroBatch = batchIndex % stepsPerEpoch
        # load data
        data, labels = sequenceClass[zeroBatch]
        # convert to tensors and return
        return tf.convert_to_tensor(data), tf.convert_to_tensor(labels)

    # create our data set for how many total steps of training we have
    dataset = tf.data.Dataset.range(stepsPerEpoch*nEpochs)
    # return dataset using map to load our batches of data, use TF to specify number of parallel calls
    return dataset.map(LoadBatchFromSequenceClass, num_parallel_calls=tf.data.experimental.AUTOTUNE)
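The dataset itself is just a range of step indices covering every epoch, and the modulo maps each step back to a batch index within the epoch. A small plain-Python illustration of that arithmetic, using made-up values (3 steps per epoch, 2 epochs):

```python
steps_per_epoch = 3  # illustrative value
n_epochs = 2         # illustrative value

# same arithmetic as zeroBatch = batchIndex % stepsPerEpoch
batch_indices = [step % steps_per_epoch for step in range(steps_per_epoch * n_epochs)]
print(batch_indices)  # [0, 1, 2, 0, 1, 2]
```

So every epoch walks through the same batch indices 0..stepsPerEpoch-1 of the sequence class, in order.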
With this function, you can update your training code to look like this:
# load our data as tensorflow datasets
training = DatasetFromSequenceClass(trainingSequence, training_steps, nEpochs, batchSize, dims=shp, n_classes=nClasses)
validation = DatasetFromSequenceClass(validationSequence, validation_steps, nEpochs, batchSize, dims=shp, n_classes=nClasses)

# train
model_object.fit(training,
                 steps_per_epoch=training_steps,
                 validation_data=validation,
                 validation_steps=validation_steps,
                 epochs=nEpochs,
                 callbacks=callbacks,
                 verbose=1)
From here, the Dataset API has many other options (such as prefetching), but this should be a good starting point.
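As one example, prefetching can be layered on top of the returned dataset. The sketch below uses a stand-in `tf.data.Dataset.range` dataset (an assumption, so that it runs on its own); in practice you would call `.prefetch(...)` on the result of `DatasetFromSequenceClass`.

```python
import tensorflow as tf

# stand-in dataset (assumption); replace with DatasetFromSequenceClass(...)
dataset = tf.data.Dataset.range(6)

# let tf.data prepare upcoming elements while the current one is being consumed
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# prefetching does not change the elements or their order
values = [int(x.numpy()) for x in dataset]
print(values)
```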