python - tf.dataset does not append batches
问题描述
I want to get tf.dataset to work. The code example below is working, but since I used .batch(30)
I would expect that the output is in the form of (30, 300, 300, 1)?
import tensorflow as tf
import numpy as np
input_array = np.random.normal(size=(300, 300, 3))
def own_generator():
yield (input_array, input_array)
dataset = tf.data.Dataset.from_generator(own_generator, (tf.float32, tf.float32)).batch(30)
data_iter = dataset.make_initializable_iterator()
sess = tf.Session()
sess.run(data_iter.initializer)
test_arr = sess.run(data_iter.get_next())
for tuple_elemnt in test_arr:
print(tuple_elemnt.shape)
The output is:
(1, 300, 300, 3)
(1, 300, 300, 3)
解决方案
The generator was falsely programmed. This is the working example:
import tensorflow as tf
import numpy as np
input_array = np.random.normal(size=(300, 300, 3))
def own_generator():
while True:
yield input_array
dataset = tf.data.Dataset.from_generator(own_generator, tf.float32).batch(30)
data_iter = dataset.make_initializable_iterator()
sess = tf.Session()
sess.run(data_iter.initializer)
test_arr = sess.run(data_iter.get_next())
print(test_arr.shape)
推荐阅读
- xml - 无法从相关路径将数据存储在 xsl 变量中
- vue.js - Vuex 11:17 错误“状态”已定义但从未使用过 no-unused-vars
- linux - VS Code 的远程 WSL 扩展中的“新窗口”和 VS Code 窗口有什么区别?
- flutter - 提供商的firebase auth电话(两个屏幕之间没有切换)
- typescript - TSyringe 注入所有实现某个接口的类
- asp.net - CSS 样式仅适用于 ASP.NET Core 5.0 MVC-app 中的主视图
- git - 预提交无法在嵌套的 go 模块中发现任何 golang 文件
- google-colaboratory - Colab 笔记本保存失败
- swagger - Swagger UI:如何隐藏 Nest.js 控制器方法参数输入字段?
- python - 用户在文本字段中键入时如何打开组合框的下拉列表?