首页 > 解决方案 > 从 keras fit_generator() 到模型输入层的转换如何准确工作

问题描述

我正在处理图像数据和一些标量元数据(如头发颜色、眼睛颜色……)。我正在使用自己编写的生成器来使用 Keras.fit_generator()功能。

该过程如下所示:

在应用了一些数据增强之后,我得到了((10,200,200,3),(10,),(10,),(10,),(10,))我的数据集的形状(想象一下:我提取了 shape 的图像(200,200,3)并将其中的 10 个 -> 堆叠在一起(10,200,200,3)。因此,我将元数据复制了 10 次 -> 形状(10,)为每个)

之后我使用 tensorflow 函数dataset = dataset.apply(tf.contrib.data.unbatch())使我的数据集的形状为((200,200,3),(),(),(),()). 从这里我现在与您分享代码:

编辑(更多代码):

.fit_generator()以下代码是我的生成器函数的最后一行,它将从main()

shape_dataset = tf.shape(dataset) # shape ((10,200,200,3),(10,),(10,),(10,),(10,)) like I mentioned above
dataset = dataset.apply(tf.contrib.data.unbatch()) # shape ((200,200,3),(),(),(),()) like I mentioned a bove 
dataset = dataset.shuffle(buffer_size = buffer_size)
dataset = dataset.batch(batch_size=batch_size) 
dataset = dataset.repeat()
iterator_all = dataset.make_one_shot_iterator()
next_all = iterator_all.get_next()

with tf.Session() as sess:
    while True:
        try:
            image, eye_color, hair_ color, labels = sess.run(next_all)
            yield [image, eye_color, hair_ color], labels

        except tf.errors.OutOfRangeError:
            print('Finished')
            break

这个张量现在将通过 keras.fit_generator()函数输入到我的网络中。输入层如下所示:

input_image = Input(shape=(200, 200, 3))
input_eye_color = Input(shape=(1,), name='input_ec')
input_hair_color = Input(shape=(1,), name='input_hc')

现在我有一个问题:

  1. 10 从((10,200,200,3),(10,),(10,),(10,),(10,))哪里通过tf.contrib.data.unbatch())函数?对我来说,感觉就像我失去了这 10 个值而只得到 1 个?

  2. fit_generator()功能以批处理方式工作,但如何?听起来很愚蠢,我感觉我的网络在((200,200,3),(),(),(),())一个迭代步骤中获得了形状数据。显然,它获取的数据类似于 ((8,10,200,200,3),(8,10,),(8,10,),(8, 10,),(8, 10,))批处理大小为 8。

有人可以用形状向我解释这个问题吗?真的,我读了很多书,但我还是不明白。

谢谢你的帮助 :-)

标签: pythontensorflowkeras

解决方案


对于您在此处描述的模型

input_image = Input(shape=(200, 200, 3), name='input_img')
input_eye_color = Input(shape=(1,), name='input_ec')
input_hair_color = Input(shape=(1,), name='input_hc')

在 keras 中,fit_generator接受以下两个输入之一:

  1. 张量列表[bsize x 200 x 200 x 3, bsize x 1, bsize x 1]
  2. 张量字典

    {'input_img':bsize x 200 x 200 x3
    'input_ec':bsize x 1,'input_hc':bsize x 1}

如您所见,这与您实际提供的内容完全不同。


推荐阅读