python - 从 keras fit_generator() 到模型输入层的转换如何准确工作
问题描述
我正在处理图像数据和一些标量元数据(如头发颜色、眼睛颜色……)。我正在使用自己编写的生成器来使用 Keras.fit_generator()
功能。
该过程如下所示:
在应用了一些数据增强之后,我得到了((10,200,200,3),(10,),(10,),(10,),(10,))
我的数据集的形状(想象一下:我提取了 shape 的图像(200,200,3
)并将其中的 10 个 -> 堆叠在一起(10,200,200,3)
。因此,我将元数据复制了 10 次 -> 形状(10,)
为每个)
之后我使用 tensorflow 函数dataset = dataset.apply(tf.contrib.data.unbatch())
使我的数据集的形状为((200,200,3),(),(),(),())
. 从这里我现在与您分享代码:
编辑(更多代码):
.fit_generator()
以下代码是我的生成器函数的最后一行,它将从main()
shape_dataset = tf.shape(dataset) # shape ((10,200,200,3),(10,),(10,),(10,),(10,)) like I mentioned above
dataset = dataset.apply(tf.contrib.data.unbatch()) # shape ((200,200,3),(),(),(),()) like I mentioned a bove
dataset = dataset.shuffle(buffer_size = buffer_size)
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.repeat()
iterator_all = dataset.make_one_shot_iterator()
next_all = iterator_all.get_next()
with tf.Session() as sess:
while True:
try:
image, eye_color, hair_ color, labels = sess.run(next_all)
yield [image, eye_color, hair_ color], labels
except tf.errors.OutOfRangeError:
print('Finished')
break
这个张量现在将通过 keras.fit_generator()
函数输入到我的网络中。输入层如下所示:
input_image = Input(shape=(200, 200, 3))
input_eye_color = Input(shape=(1,), name='input_ec')
input_hair_color = Input(shape=(1,), name='input_hc')
现在我有一个问题:
10 从
((10,200,200,3),(10,),(10,),(10,),(10,))
哪里通过tf.contrib.data.unbatch())
函数?对我来说,感觉就像我失去了这 10 个值而只得到 1 个?该
fit_generator()
功能以批处理方式工作,但如何?听起来很愚蠢,我感觉我的网络在((200,200,3),(),(),(),())
一个迭代步骤中获得了形状数据。显然,它获取的数据类似于((8,10,200,200,3),(8,10,),(8,10,),(8, 10,),(8, 10,))
批处理大小为 8。
有人可以用形状向我解释这个问题吗?真的,我读了很多书,但我还是不明白。
谢谢你的帮助 :-)
解决方案
对于您在此处描述的模型
input_image = Input(shape=(200, 200, 3), name='input_img')
input_eye_color = Input(shape=(1,), name='input_ec')
input_hair_color = Input(shape=(1,), name='input_hc')
在 keras 中,fit_generator
接受以下两个输入之一:
- 张量列表
[bsize x 200 x 200 x 3, bsize x 1, bsize x 1]
张量字典
{'input_img':
bsize x 200 x 200 x3
,
'input_ec':bsize x 1
,'input_hc':bsize x 1
}
如您所见,这与您实际提供的内容完全不同。
推荐阅读
- pagination - 如何在 GraphQL 中强制分页
- excel - phpspreadsheet setFormatCode 性能问题
- python - 如何定义文件 .exe 读取的多个参数?
- hyperledger-fabric - 隐式策略评估失败 - 满足 1 个子策略,但此策略需要满足 2 个“背书”子策略
- python - 无法在 http://localhost:5000 访问 Dockerized Flask 应用程序
- c++ - `std::async` 用于 C++ 中的异步回复
- vhdl - 为什么这个 vhdl 代码会陷入无限循环?
- pybind11 - 在绑定自定义类型时,我还必须绑定自定义类型 API 中显示的所有其他类型吗?
- reactjs - React Material UI KeyboardDateTimePicjer 上的 maxDate 有问题
- javascript - 为什么 deepEqual 方法不适用于此代码?