python - 对 Model.fit 使用experimental_relax_shapes=True
问题描述
我有一个神经网络,通过 Model.fit() 进行训练需要很长时间。我总是得到以下信息:
WARNING:tensorflow:7 out of the last 12 calls to <function Model.make_train_function.<locals>.train_function at 0x7f128aee0ae8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
有没有办法告诉 Model.fit() 使用 experimental_relax_shapes=True ?
换句话说,我想写
@tf.function(experimental_relax_shapes=True)
在 Model.fit() 的定义之前。我怎样才能做到这一点?
如果我做这样的事情:
@tf.function(experimental_relax_shapes=True)
def fit(x):
return model.fit(x)
我明白了
RuntimeError: Detected a call to `Model.fit` inside a `tf.function`. `Model.fit is a high-level endpoint that manages its own `tf.function`. Please move the call to `Model.fit` outside of all enclosing `tf.function`s. Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.
所以我实际上想将 Model.fit 管理的 tf.function 更改为使用 experimental_relax_shapes=True 。
解决方案
我通过编写自己的训练循环来解决它,就像这里的第一个示例一样: https ://keras.io/guides/writing_a_training_loop_from_scratch/
奇怪的是,当我将 @tf.function(experimental_relax_shapes=True) 放在运行循环的函数的定义之前,训练步骤大约需要 30 到 40 秒。当我没有尝试时,他们花了不到一秒钟的时间。
推荐阅读
- c# - WCF RESTful 服务 - POST- WebFaultException 完整示例
- visual-studio - VS 2017 颜色主题编辑器 - 如何更改 Intellisense 自动完成背景颜色
- python - 根据其内容交换 numpy 二维数组中的数组
- css - 如何动态调整 vega 图表的大小以适应 CSS 网格
- python - %mprun 不断给出名称错误(导入模块的问题)
- post - 如何使用 Restassured 库从请求或响应对象中获取已在 POST 请求中传递的 body(json) 的内容?
- java - 如何使用子类外部的子引用访问与子变量同名的父类变量?
- marklogic - MarkLogic - Alternatives to using triggers
- asp.net-core - 带有 ASP.NET Core 2.2 的身份服务器 4
- c# - 如何解决字符串中只有一些字符的编码问题?