首页 > 解决方案 > Keras 预测挂在 Heroku 上的 graph.as_default() 内

问题描述

我在烧瓶应用程序中加载了一个 Keras 模型并进行预测。在本地它工作正常,但在 heroku 上它无限期地挂model.predict()在下面。

import tensorflow as tf
model = import_model()
graph = tf.get_default_graph()
....

def predict_sentiment(X, tokens):

    for idx,_ in enumerate(tokens):

        token = tokens[idx][0]
        x_input = X[idx]

        global graph
        with graph.as_default():
            yhat = model.predict([[x_input]], verbose=0) # <--- HANGS
        ...

最初我遇到了错误,Tensor Tensor(...) is not an element of this graph.因为每个请求线程都在创建一个新的 tensorflow 会话。

所以我修复了,keras.backend.clear_session()但是在每个请求上重新创建会话太慢了。

然后我从这篇文章中找到了我当前的解决方案。即创建一个全局会话并在每个请求中引用它。

在本地,这是快速且有效的。但是在heroku上,它卡在这条线上而没有抛出任何错误,我不知道为什么。

也试过:

在里面加载模型with graph.as_default():

import tensorflow as tf
graph = tf.get_default_graph()

with graph.as_default():
    model = import_model()

我知道模型存在,因为我可以在同一个with...块中打印模型。

print(model.layers[0].input)

=> <keras.engine.sequential.Sequential object at 0x7f829836f588>.

并且还可以打印有关图层的信息。

print(model.layers[0].input)

=> Tensor("dense_1_input:0", shape=(?, 75684), dtype=float32)

但它不会预测。

标签: pythontensorflowherokuflaskkeras

解决方案


Keras 2.2.5 和 tensorflow 1.14.0 一起可以很好地克服这个错误


推荐阅读