首页 > 解决方案 > 使用 TensorFlow 数据集 from_generator() 使用自定义生成器和 ImageDataGenerator 创建多输入/输出

问题描述

我正在尝试扩展我的模型,该模型使用“集群损失”扩展,到目前为止,该实现在 MNIST 上有效,但我希望从真实数据集的数据增强和多处理中受益。

简而言之,该网络遵循“中心损失”完成的工作,这有点像连体网络。架构的重要部分是模型有 2 个输入和 2 个输出。因此,我实现了一个自定义生成器,以便为模型提供如下内容:

def my_generator(stop):
    i = 0
    while i < stop:
        batch = train_gen.next()
        img = batch[0]
        labels = batch[1]
        labels_size = np.shape(labels)
        cluster = np.zeros(labels_size)
        x = [img, labels]
        y = [labels, cluster]

        yield x, y
        i += 1

它调用定义如下的生成器(“train_gen”):

generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, horizontal_flip=True)
train_gen = generator.flow_from_dataframe(df, x_col='img_path', y_col='label',
                                          class_mode='categorical',
                                          target_size=(32, 32),
                                          batch_size=batch_size)

如果我在 fit 函数中只设置一个工人,则生成器可以工作。但显然它非常缓慢......所以我尝试使用来自 Tensorflow (tf.data.Dataset.from_generator) 的推荐 tf.Data 来拟合我的模型,但将其设置如下,

ds = tf.data.Dataset.from_generator(my_generator,
                                    args=[num_iter],
                                    output_types=([tf.float32, tf.float32], [tf.float32, tf.float32]))

我收到以下错误:

TypeError: Cannot convert value [tf.float32, tf.float32] to a Tensorflow DType. 

从那里,我尝试了多种方法,遵循这篇文章

例如,尝试返回元组而不是数组:

x = (img, labels)
y = (labels, cluster)

但我得到了:

ValueError: as_list() is not defined on an unknown TensorShape 

这个事情谁有经验?我不确定是否理解该错误,并且我认为我可以更改“output_types”参数,但 TensorFlow 没有“列表”或“元组”DType 参数。

这是我的代码的链接,该代码从 cifar10 构建了一个小型图像数据集以提供玩具模型。

标签: tensorflowgeneratordata-augmentation

解决方案


我不认为您的发电机按您预期的那样工作。每次调用它都会设置 i=0。之后的代码

yield x, y
i += 1

i += 1 从不执行。将打印语句如下

yield x, y
i += 1
print ('the value of i is ',i)

你会看到它永远不会执行。

如果你执行以上是真的

x,y=next(my_generator(2))

这就是生成器的使用方式。但是,如果您执行

x,y=my_generator(2)

然后 i += 1 语句确实执行。通常对于生成器,您可以将它们与 next(my_generator) 一起使用。model.fit 我相信通过在您指定的生成器上使用 next() 可以获得下一批。


推荐阅读