首页 > 解决方案 > 在增强训练期间无法在 Keras iterator.py 中的断点处停止

问题描述

我创建了数据生成器类的两个实例,从 keras 序列类扩展,一个用于训练,一个用于验证数据。但是,在我的源代码级别,我只能看到验证生成器在每个时期之间重新迭代。我看不到训练生成器。因此,我无法验证训练数据的扩充是否符合我的意图。在这些代码片段中,aug 是一个参数字典,这些参数会转到我的 myDataGen 序列扩展中的 keras ImageDataGenerator 实例。我通常不会增加验证数据,但这就是我偶然发现这个难题的方式:

    aug = dict(fill_mode='nearest',
                        rotation_range=10,
                        zoom_range=0.3,
                        width_shift_range=0.1,
                        height_shift_range=0.1
                        )
    training_datagen = myDataGen(Xdata_train,ydata_train,**aug)
    validation_datagen = myDataGen(Xdata_test,ydata_test,**aug)

    history = model.fit(training_datagen,
                                validation_data=validation_datagen,
                                validation_batch_size=16,
                                epochs=50,
                                shuffle=False,
                                )

一切正常,我得到了很好的结果,但我只是想确定增强。因此,我可以通过浏览 keras 中的各种函数来收集我编写的数据生成器填充了一个较低级别的 tensorflow 数据集,然后每个 epoch 进行迭代。我只是看不到张量流数据集是如何在每个时期增加的。

现在,我还意外发现,虽然 fit 方法不支持验证数据的生成器,但它确实有效,并且具有我希望训练生成器具有的有趣功能,即重新读取数据从磁盘,以便它在我自己的源代码级别重新增强。

最重要的是,我可以看到 tensorflow Dataset.cache() 方法可能在第一个 epoch 之后将我的训练数据集存储在内存中的提示。我可以以某种方式 uncache() 它来强制重新读取和重新增强,或者有人可以指出张量流数据集在迭代时如何调用增强方法?

唔。这个线程TF Dataset API for Image augmentation清楚地表明,直接在 tensorflow Dataset API 中编写增强方法很容易,但是贡献者在评论中写道,您不能在 tf.data.Dataset 上使用 keras.ImageDataGenerator。但我可以在 keras 模块中清楚地看到,我的 keras 数据集正在“适应”到底层的 tf.data.Dataset 中。如果此评论属实,它将解释为什么我似乎无法打破 ImageDataGenerator 对我的训练数据的扩充。但这怎么可能是真的呢?

标签: pythonkerasdata-generation

解决方案


我的困惑来自初学者的错误,即忽略了一个事实,即在将 keras 源代码编译到 gpu 之后,它当然不能在 keras 源代码的级别上中断。但是从这种混乱中产生的有趣之处在于,您可以为验证数据编写一个 keras 生成器,并为每个 epoch 打破它,因为它显然没有编译到 gpu 上......因为 keras 不支持验证数据的生成器!只是生成器处理得很好,没有运行时错误。一个不起眼的发现,但希望它可以帮助某人。


推荐阅读