python - Keras 预测挂在 Heroku 上的 graph.as_default() 内
问题描述
我在烧瓶应用程序中加载了一个 Keras 模型并进行预测。在本地它工作正常,但在 heroku 上它无限期地挂model.predict()
在下面。
import tensorflow as tf
model = import_model()
graph = tf.get_default_graph()
....
def predict_sentiment(X, tokens):
for idx,_ in enumerate(tokens):
token = tokens[idx][0]
x_input = X[idx]
global graph
with graph.as_default():
yhat = model.predict([[x_input]], verbose=0) # <--- HANGS
...
最初我遇到了错误,Tensor Tensor(...) is not an element of this graph.
因为每个请求线程都在创建一个新的 tensorflow 会话。
所以我修复了,keras.backend.clear_session()
但是在每个请求上重新创建会话太慢了。
然后我从这篇文章中找到了我当前的解决方案。即创建一个全局会话并在每个请求中引用它。
在本地,这是快速且有效的。但是在heroku上,它卡在这条线上而没有抛出任何错误,我不知道为什么。
也试过:
在里面加载模型with graph.as_default():
。
import tensorflow as tf
graph = tf.get_default_graph()
with graph.as_default():
model = import_model()
我知道模型存在,因为我可以在同一个with...
块中打印模型。
print(model.layers[0].input)
=> <keras.engine.sequential.Sequential object at 0x7f829836f588>
.
并且还可以打印有关图层的信息。
print(model.layers[0].input)
=> Tensor("dense_1_input:0", shape=(?, 75684), dtype=float32)
但它不会预测。
解决方案
Keras 2.2.5 和 tensorflow 1.14.0 一起可以很好地克服这个错误
推荐阅读
- linux - Git push 发送并行副本到测试服务器
- python - PYTHON - 如何删除包含不需要数据的某些行
- sql - 如何在使用 case 语句添加新列后对查询中的行进行分组和计数
- c# - 为什么使用构建器而不是参数对象?
- python - 如何优化抽象语法树?
- firebase - Firestore OR 查询改进
- r - 如何将一列时间数据四舍五入到 r 中最接近的 15 分钟?
- gpu - 如何使用 Profiling+openCL+Sycl+DPCPP 测量 GPU 的执行时间
- python - 将多个 ImageView 项添加到 Qt.Window 以在 python 的一个窗口中获取多个图像
- ruby-on-rails - Rails ActiveRecord 嵌套 .create!使用 PSQL POINT 类型