首页 > 解决方案 > 加载 keras 模型并将其缓存在变量中,而无需重新加载

问题描述

在我的 Flask 应用程序开始时加载模型,然后在我的端点中使用它进行预测会导致错误

'ValueError: Tensor Tensor("dense/Softmax:0", shape=(?, 4), dtype=float32) 不是该图的元素。'

model = keras.models.load_model("model.h5")

@app.route("/predict", methods=["POST"])
def predict():
    json_data = request.get_json()

    variable = preparePredictionInput(
        [variable], alphabetDict, maxVariableLength)
    prediction = list(model.predict(variable, steps=1, verbose=1)[0])

但是每次调用预测端点时加载 keras 模型似乎工作得很好

@app.route("/predict", methods=["POST"])
def predict():
    json_data = request.get_json()
    model = keras.models.load_model("model.h5")

    variable = preparePredictionInput(
        [variable], alphabetDict, maxVariableLength)
    prediction = list(model.predict(variable, steps=1, verbose=1)[0]) 

有没有办法来解决这个问题?这从根本上降低了每次都必须重新加载模型的性能。

标签: pythontensorflowkeras

解决方案


似乎您的模型变量不是全局的。看看下面的代码:

def init():
  global model
  model = lkeras.models.load_model("model.h5")

@app.route("/predict", methods=["POST"])
def predict():
    json_data = request.get_json()
    variable = preparePredictionInput([variable], alphabetDict, maxVariableLength)
    prediction = list(model.predict(variable, steps=1, verbose=1)[0])


if __name__ == "__main__":
    init()
    app.run()

推荐阅读