tensorflow-datasets - 关于 tf.repeat().batch(batch_size)
问题描述
我正在研究张量流。关于tensorflow.data.Dataset中的repeat函数,如果repeat函数repeat()中没有参数,则张量应该无限重复。但是,当没有参数的重复函数与循环语句下的批处理函数结合使用时,它会创建一个没有无限重复的结果,如下所示。我无法理解这个过程。你能用下面的例子解释一下重复功能吗?谢谢你!
for count_batch in ds_counter.repeat().batch(10).take(10):
print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 0 1 2 3 4]
[ 5 6 7 8 9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 0 1 2 3 4]
[ 5 6 7 8 9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
解决方案
由于您使用.take(10)
的是最后一个链接方法,因此结果数据集因此被限制为只有 10 个样本。这里的单个sample
将是单个批次中的所有元素。有 10 批无限重复,您将使用.take(10)
. 将您的代码更改为以下应该会给您预期的结果。
ds_counter = tf.data.Dataset.range(25)
for count_batch in ds_counter.repeat().batch(10):
print(count_batch.numpy())
推荐阅读
- javascript - react-transition-group 中的退出延迟动画
- sql-server - Sqlserver 查询按行显示
- python-2.7 - AWS 胶水无法读取 MySQL 数据库中的中文字符
- python - 根据多个值检查和更新列
- tensorflow - 如何在 Tensorflow 的计算图中用另一个变量替换一个变量?
- javascript - Angular 6 - 在依赖注入之前使用异步调用初始化服务
- frameworks - podspec source 和 source_files 没有采用本地路径
- ubuntu - pip安装后找不到cget
- python - 如何传递python命令行参数?
- android - 在 Firebase Google Analytics 中看不到自定义广告系列参数结果