首页 > 解决方案 > 如何在 TensorFlow 云中调用 apply_gradients()?

问题描述

我正在尝试在 Google Cloud Platform 中训练我的 tensorflow 模型(具有自定义训练循环)。当我将训练脚本提交到云端时,作业失败。日志说:

apply_gradients() cannot be called in cross-replica context. Use tf.distribute.Strategy.run to enter replica context.

如果我尝试按照它的建议去做(明确指出分布式训练策略,我会按照本教程进行操作),那么我会收到错误消息:

"RuntimeError: Mixing different tf.distribute.Strategy objects: <tensorflow.python.distribute.one_device_strategy.OneDeviceStrategy object at 0x7fca824ddb38> is not <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fc90eb5bf6

即使我使用 OneDeviceStrategy,我也会得到相同的错误,但参考了 OneDeviceStrategy 的两个实例。

关于什么可能是错误的以及如何解决它的任何想法?是第一次尝试在 Google Cloud 中运行自定义 tensorflow 模型。

标签: tensorflowgoogle-cloud-platform

解决方案


推荐阅读