首页 > 解决方案 > 如何获取 tf.data.Dataset 的长度(data_size / batch_size)?

问题描述

我想获取我的 tf.data.Dataset 的长度。(数据大小/批处理大小)

在 Pytorch 中,我可以通过简单的代码得到这个:

length = len(data_loader)

但是,它在 tensorflow 2.0 中不起作用。

我怎么得到这个?

标签: tensorflowtensorflow2.0

解决方案


在 TensorFlow 2.0 中,您创建一个tf.data.Dataset对象,即 Python 可迭代对象。

在循环遍历所有元素之前,您不会事先知道数据集中有多少元素。

因此,假设您以这种方式创建了一个数据集:

batch_size = 12
dataset = tf.data.Dataset.from_tensor_slices(something).batch(batch_size)

您可以通过这种方式获得批次总数:

number_of_batches = len([_ for _ in iter(dataset)])

推荐阅读