首页 > 解决方案 > 如何在回调中获得批次中使用的确切训练示例?

问题描述

我在 Keras 中训练神经网络时遇到问题。每个 epoch,loss 都会稳步下降,达到 1e-9 左右,然后在 epoch 中间的某个地方(可能是任何地方),loss 会上升到 5e-5,最终稳定在每个 epoch 相同的最终 loss。我相信这是由于我的数据集中的一些脏数据导致模型无法训练超过某个点,尽管我真的不确定。

为了检验我的假设,我想创建一个自定义的 Keras 回调对象,它将确定一个批次后损失是否有足够大的跳跃,并指出哪个批次导致了跳跃。问题是batch提供给的参数keras.callbacks.Callback.on_batch_end只是批号实际上并不是该批中使用的训练示例。此外,logs传入的 dict 也只包含lossand acc

这意味着我实际上无法确定哪些数据导致了损失的跳跃。有没有办法可以确定导致每个时期跳跃的确切训练示例?有什么方法可以在回调中访问它吗?

标签: pythontensorflowkerasneural-networktensorflow2.0

解决方案


推荐阅读