tensorflow - 如何获取 tf.data.Dataset 的长度(data_size / batch_size)?
问题描述
我想获取我的 tf.data.Dataset 的长度。(数据大小/批处理大小)
在 Pytorch 中,我可以通过简单的代码得到这个:
length = len(data_loader)
但是,它在 tensorflow 2.0 中不起作用。
我怎么得到这个?
解决方案
在 TensorFlow 2.0 中,您创建一个tf.data.Dataset
对象,即 Python 可迭代对象。
在循环遍历所有元素之前,您不会事先知道数据集中有多少元素。
因此,假设您以这种方式创建了一个数据集:
batch_size = 12
dataset = tf.data.Dataset.from_tensor_slices(something).batch(batch_size)
您可以通过这种方式获得批次总数:
number_of_batches = len([_ for _ in iter(dataset)])
推荐阅读
- android-activity - Android - 活动/片段洗牌取决于来自服务器的某些事件
- python - Conda - ModuleNotFoundError:没有名为“火炬”的模块
- python - 如何将 json 响应分解为特定于名称值的多个部分?
- python - 如何使用图像和标签为 TensorFlow 创建数据集
- javascript - 如何从全局范围中删除 CSS 并仅导入到 React、Gatsby 中的特定页面
- memgraphdb - 由于并发操作而无法序列化:memgraph
- node.js - nodejs:全局this,相同的代码在文件和repl中运行时输出不同的结果
- ionic-framework - 在 Ionic 5 中安装 fcm-with-dependency-updated Cordova 插件时出现 Cocoapods 问题
- java - 在过去 24 小时内打印到控制台并通过电子邮件发送所有失败的作业、管道和多分支管道
- python - 从字符串中删除特殊字符,而不是用空格替换它们