首页 > 解决方案 > 将 Horovod 与 tf.keras 一起使用时如何从检查点恢复?

问题描述

注意:我使用的是 TF 2.1.0 和 tf.keras API。我在 0.18 和 0.19.2 之间的所有 Horovod 版本中都遇到了以下问题。

我们应该hvd.load_model()在从 tf.keras h5 检查点恢复时调用所有等级,还是只应该在等级 0 上调用它并让BroadcastGlobalVariablesCallback回调与其他工作人员共享这些权重?方法 1 是否不正确/无效,因为它会扰乱训练或产生与方法 2 不同的结果?

我目前正在训练一个带有一些 BatchNorm 层的基于 ResNet 的模型,如果我们只尝试在第一层加载模型(并在其他层上构建/编译模型),我们会遇到一个停滞的张量问题(https: //github.com/horovod/horovod/issues/1271)。但是,如果我们hvd.load_model在恢复时调用所有等级,训练开始正常恢复但它似乎立即发散,所以我很困惑是否在所有等级上加载检查点模型(带有hvd.load_model)会以某种方式导致训练发散?但同时,由于https://github.com/horovod/horovod/issues/1271,我们无法仅将其加载到 rank 0,导致 Batch Norm 在 horovod 中挂起。有没有人能够成功调用hvd.load_model仅在使用 BatchNorm tf.keras 层时排名为 0?有人可以在这里提供一些提示吗?

谢谢!

标签: pythontensorflowtensorflow2.0tf.kerashorovod

解决方案


根据这个:https://github.com/horovod/horovod/issues/120,这是解决方案:

You should also be able to specify optimizer via custom object:
model = keras.models.load_model('file.h5', custom_objects={
    'Adam': lambda **kwargs: hvd.DistributedOptimizer(keras.optimizers.Adam(**kwargs))
})

推荐阅读