首页 > 解决方案 > 在 tf.function 中使用 tf.keras.model.predict

问题描述

import tensorflow as tf
from tensorflow.keras.layers import Input, Multiply
from tensorflow.keras.models import Model

print(tf.__version__) # 2.1.1

def build_model():
    inputs = Input(shape=(42,))
    x = Multiply()([inputs, inputs])

    return Model(inputs, x)

x = tf.constant(3, shape=(1, 42), dtype='float32')

model = build_model()
print(model.predict(x)) # works fine

@tf.function
def f(x):
    x += 5
    return model.predict(x) # throws ValueError

print(f(x))

运行这段简单的代码会产生

    ValueError: When using data tensors as input to a model, you should specify the `steps` argument.

调用时model.predict,位于@tf.function.

为什么会这样?我不应该model.predict在里面使用@tf.function吗?还是以其他方式做?

如果我将 替换model.predict(input)为 just model(input),一切正常。另外,如果我steps=1按照错误提示添加参数,则会出现另一个错误

    ValueError: Unknown graph. Aborting.

标签: pythontensorflowkeras

解决方案


我认为 TF 版本 > 2.3 在它的输出错误中给出了解释,我用 TF 2.3.1 运行你的代码,它给出了错误:

RuntimeError: Detected a call to `Model.predict` inside a `tf.function`. 
`Model.predict is a high-level endpoint that manages its own `tf.function`. 
Please move the call to `Model.predict` outside of all enclosing `tf.function`s. 
Note that you can call a `Model` directly on `Tensor`s inside a `tf.function` like: `model(x)`.

推荐阅读