首页 > 解决方案 > 使用 fit 恢复训练会将批处理步骤重置为 0

问题描述

我最近重写了我的 TensorFlow 2 自定义训练循环来使用fit,使用ModelCheckpoint回调来管理我之前在循环中手动执行的检查点。这一切都很好,但我有一个一直在努力解决的问题:在正确的批处理步骤中恢复训练。我将TensorBoard回调与 一起使用update_freq=50,当我恢复训练(从保存的检查点加载权重)时,我通常会看到如下内容:

在此处输入图像描述

以上结果来自 2 次运行,单个 epoch 包含 250 个批处理步骤(只是一个玩具数据集)。第一行是 2 个 epoch 的第一次运行,在 500 步后结束(摘要每 50 步更新一次,但不是在最后一步之后,大概,所以在 500 处的最后一步丢失了)。中间的直线只是绘制到第二次运行开始的图形线,由底线表示。我运行了 3 个 epoch,因此有 750 个批处理步骤 (250 * 3)。

问题是每次重新开始训练时步数都会从 0 开始。我怎样才能解决这个问题?大概是 TensorBoard 回调从 0 开始跟踪每个时期的步骤... fit 方法有一个initial_epoch参数,我用它在正确的时期重新开始,但是是否可以跟踪全局批处理步骤?我global_step在旧的(TF2 之前的)代码中看到过,是用来实现这个的吗?

标签: tensorflowmachine-learningkerastensorflow2.0tensorboard

解决方案


好的,似乎管理这个的方法是通过runs的概念,这相当于为不同的培训课程创建单独的日志子目录 - 如这个 Colab notebook 所示


推荐阅读