Using experimental_relax_shapes=True with Model.fit

Question

I have a neural network that takes a very long time to train with Model.fit(). I keep getting the following message:

WARNING:tensorflow:7 out of the last 12 calls to <function Model.make_train_function.<locals>.train_function at 0x7f128aee0ae8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.

Is there a way to tell Model.fit() to use experimental_relax_shapes=True?

In other words, I would like to write

@tf.function(experimental_relax_shapes=True)

before the definition of Model.fit(). How can I do that?

If I do something like this:

@tf.function(experimental_relax_shapes=True)
def fit(x):
  return model.fit(x)

I get

RuntimeError: Detected a call to `Model.fit` inside a `tf.function`. `Model.fit is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.fit` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.

So what I actually want is to change the tf.function that Model.fit manages internally so that it uses experimental_relax_shapes=True.
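Cause (2) in the warning can be reproduced outside Keras. Below is a minimal sketch, assuming TensorFlow 2.x. It counts traces with a Python-side counter (a hypothetical helper named trace_count), which only increments while a function is being traced. One function uses a plain @tf.function. The other fixes an input_signature with a None batch dimension, which is a standard alternative to relaxed shapes when the varying dimension is known in advance.

```python
import tensorflow as tf

# Hypothetical counter: Python code inside a tf.function only runs during tracing,
# so incrementing it counts how many times each function was (re)traced.
trace_count = {"default": 0, "signature": 0}

@tf.function
def square_default(x):
    trace_count["default"] += 1
    return x * x

# A None dimension in the signature lets one trace serve every batch size.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def square_signature(x):
    trace_count["signature"] += 1
    return x * x

# Feed batches whose first dimension changes on every call.
for batch_size in (2, 5, 7):
    x = tf.ones([batch_size, 3])
    square_default(x)
    square_signature(x)

print(trace_count)
```

With the default settings, each new batch shape triggers another trace. The signature version is traced only once.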

Tags: python, tensorflow, keras

Answer


I solved it by writing my own training loop, like the first example here: https://keras.io/guides/writing_a_training_loop_from_scratch/
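Below is a minimal sketch of that approach, assuming TensorFlow 2.x with tf.keras. It is not the answerer's actual code. The toy model, data and batch sizes are hypothetical and chosen to produce batches of different shapes. Only the per-batch train_step is compiled with relaxed shapes. The epoch and batch loops stay in ordinary eager Python, as in the Keras guide. In TensorFlow 2.9 and later, experimental_relax_shapes is deprecated in favor of reduce_retracing=True.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy regression data: y = x @ w_true.
rng = np.random.default_rng(0)
w_true = np.array([[1.0], [-2.0], [0.5], [3.0]], dtype=np.float32)
x_all = rng.random((100, 4), dtype=np.float32)
y_all = x_all @ w_true

# Batches of deliberately different sizes, to mimic varying input shapes.
batch_sizes = [8, 13, 21, 34, 24]
bounds = np.cumsum([0] + batch_sizes)
batches = [
    (tf.constant(x_all[a:b]), tf.constant(y_all[a:b]))
    for a, b in zip(bounds[:-1], bounds[1:])
]

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

# Only the single training step is compiled. Relaxed shapes let TensorFlow
# generalize the batch dimension instead of retracing for every new size.
@tf.function(experimental_relax_shapes=True)
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# The loop itself runs eagerly in plain Python.
epoch_losses = []
for epoch in range(30):
    losses = [float(train_step(x, y)) for x, y in batches]
    epoch_losses.append(sum(losses) / len(losses))

print(f"first epoch loss {epoch_losses[0]:.4f}, last epoch loss {epoch_losses[-1]:.4f}")
```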

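Strangely, when I put @tf.function(experimental_relax_shapes=True) before the definition of the function that runs the loop, each training step took about 30 to 40 seconds. Without it, each step took less than a second.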
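The answer does not explain the slowdown. One plausible cause, offered here as an assumption, is how tf.function handles Python loops. Inside a tf.function, AutoGraph unrolls a loop over a Python iterable (such as range or a list of batches) at trace time, so the traced graph grows with the number of iterations. A loop over tf.range is instead converted into a single graph-level while loop. The sketch below, assuming TensorFlow 2.x, compares the number of operations in the traced graphs. The function names are hypothetical.

```python
import tensorflow as tf

@tf.function
def python_loop(x):
    # Python range: unrolled while tracing, one multiply per iteration in the graph.
    for _ in range(100):
        x = x * 1.01
    return x

@tf.function
def graph_loop(x):
    # tf.range: AutoGraph turns this into one while-loop op in the graph.
    for _ in tf.range(100):
        x = x * 1.01
    return x

x = tf.constant(1.0)
n_python = len(python_loop.get_concrete_function(x).graph.get_operations())
n_graph = len(graph_loop.get_concrete_function(x).graph.get_operations())
print(n_python, n_graph)
```

This is why decorating only the per-step function, and leaving the epoch loop in eager Python, is usually the cheaper choice.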

