首页 > 解决方案 > 具有恒定大小批次的 tf.data.Dataset

问题描述

我有一个包含 19 个元素和批量大小为 10 的数据集。我将数据集设置为连续迭代相同的元素,但我注意到最后一批只有 4 个元素而不是 5 个,然后从 5、5、 5、4 等等。

如何强制迭代器用来自下一次迭代的元素填充较短的批次,以便所有批次具有相同的大小?

PS只是为了理解,这不是训练模型时的明显行为吗?

标签: tensorflowtensorflow-datasets

解决方案


要具有此行为,.repeat()应在batch()or之前调用该方法padded_batch()。所以:

file_names = [...]
def my_map_func(record):
    ....
dataset = tf.data.TFRecordDataset(file_names)\
    .map(map_func=my_map_func)\
    .repeat()\  # here!
    .batch(5)

推荐阅读