python - 在 tf.function 中使用 tf.keras.model.predict
问题描述
import tensorflow as tf
from tensorflow.keras.layers import Input, Multiply
from tensorflow.keras.models import Model
print(tf.__version__) # 2.1.1
def build_model():
inputs = Input(shape=(42,))
x = Multiply()([inputs, inputs])
return Model(inputs, x)
x = tf.constant(3, shape=(1, 42), dtype='float32')
model = build_model()
print(model.predict(x)) # works fine
@tf.function
def f(x):
x += 5
return model.predict(x) # throws ValueError
print(f(x))
运行这段简单的代码会产生
ValueError: When using data tensors as input to a model, you should specify the `steps` argument.
调用时model.predict
,位于@tf.function
.
为什么会这样?我不应该model.predict
在里面使用@tf.function
吗?还是以其他方式做?
如果我将 替换model.predict(input)
为 just model(input)
,一切正常。另外,如果我steps=1
按照错误提示添加参数,则会出现另一个错误
ValueError: Unknown graph. Aborting.
解决方案
我认为 TF 版本 > 2.3 在它的输出错误中给出了解释,我用 TF 2.3.1 运行你的代码,它给出了错误:
RuntimeError: Detected a call to `Model.predict` inside a `tf.function`.
`Model.predict is a high-level endpoint that manages its own `tf.function`.
Please move the call to `Model.predict` outside of all enclosing `tf.function`s.
Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.
推荐阅读
- python - 增量实体解析/记录链接的最佳(PostgreSQL?)数据模型和处理
- reactjs - 使用 localhost 而不是 localhost:8081 我有 nginx、webpack、react stack 如何?
- python-multiprocessing - 为什么某些代码在 python 进程中的函数中不起作用?
- firebase - 如何为我的数据库编写 firebase 规则,以防止某些用户能够写入?
- php - 使用 html 循环通过 github 图像目录
- javascript - 在 Redux React 中将对象添加到嵌套数组
- apache-spark - spark sql时间戳数据类型是否实际存储时区?
- mongodb - Mongo db聚合查询
- java - Java高阶函数面临的问题。没有从高阶函数中获得正确的好处
- python - Django Rest Api 将成员(用户)添加到类模型