python - 如何使用烧瓶为 keras 模型提供推理服务?
问题描述
我有一个yolo v3 keras模型的推理对象检测代码
#! /usr/bin/env python
import os
import argparse
import json
import cv2
from utils.utils import get_yolo_boxes, makedirs
from utils.bbox import draw_boxes
from keras.models import load_model
from tqdm import tqdm
import numpy as np
import flask
import io
from PIL import Image
from keras.preprocessing.image import img_to_array
config_path = "config.json"
input_path = "test.jpg"
output_path = "output"
with open(config_path) as config_buffer:
config = json.load(config_buffer)
makedirs(output_path)
net_h, net_w = 416, 416
obj_thresh, nms_thresh = 0.5, 0.45
os.environ['CUDA_VISIBLE_DEVICES'] = config['train']['gpus']
infer_model = load_model(config['train']['saved_weights_name'])
image = cv2.imread(input_path)
# predict the bounding boxes
boxes = get_yolo_boxes(infer_model, [image], net_h, net_w, config['model']['anchors'], obj_thresh, nms_thresh)[0]
# draw bounding boxes on the image using labels
_,outputs = draw_boxes(image, boxes, config['model']['labels'], obj_thresh)
print(outputs)
# write the image with bounding boxes to file
cv2.imwrite(output_path + input_path.split('/')[-1], np.uint8(image))
这工作得很好,在终端中给出了预期的输出类和坐标
{'classes':'人 99.97%','X2':'389','X1':'174','Y1':'8','Y2':'8'}
但是,当我通过使用官方 keras 转换为烧瓶文档将上述代码转换为基于服务的 REST api 时,如下所示:
#! /usr/bin/env python
import os
import argparse
import json
import cv2
from utils.utils import get_yolo_boxes, makedirs
from utils.bbox import draw_boxes
from keras.models import load_model
from tqdm import tqdm
import numpy as np
import flask
import io
from PIL import Image
from keras.preprocessing.image import img_to_array
config_path = "config.json"
input_path = "test.jpg"
output_path = "output"
with open(config_path) as config_buffer:
config = json.load(config_buffer)
makedirs(output_path)
net_h, net_w = 416, 416
obj_thresh, nms_thresh = 0.5, 0.45
app = flask.Flask(__name__)
os.environ['CUDA_VISIBLE_DEVICES'] = config['train']['gpus']
infer_model = load_model(config['train']['saved_weights_name'])
def prepare_image(image_path):
image = cv2.imread(image_path)
return image
@app.route("/predict", methods=["POST"])
def predict():
# initialize the data dictionary that will be returned from the
# view
data = {"success": False}
# ensure an image was properly uploaded to our endpoint
if flask.request.method == "POST":
if flask.request.files.get("image"):
# read the image in PIL format
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image))
# preprocess the image and prepare it for classification
image = img_to_array(image)
boxes = get_yolo_boxes(infer_model, [image], net_h, net_w, config['model']['anchors'], obj_thresh, nms_thresh)[0]
_,outputs = draw_boxes(image, boxes, config['model']['labels'], obj_thresh)
data.append(outputs)
print(data)
# indicate that the request was a success
data["success"] = True
# return the data dictionary as a JSON response
return flask.jsonify(data)
if __name__ == "__main__":
print(("* Loading Keras model and Flask starting server..."
"please wait until server has fully started"))
app.run()
在 5000 端口成功运行
但是当我试图通过 POST api 使用
curl -X POST -F image=@test.jpg 'http://localhost:5000/predict'
给出这个错误
raise ValueError("Tensor %s is not an element of this graph." % obj) ValueError: Tensor Tensor("conv2d_59/BiasAdd:0", shape=(?, ?, ?,
255), dtype=float32) 不是该图的元素。127.0.0.1 - - [15/Aug/2019 15:11:23] “POST /predict HTTP/1.1”500 -
我不明白为什么相同的预测函数在没有烧瓶的情况下也能工作,但会出错。
解决方案
我遇到了同样的问题,这是一个 keras 问题。model._make_predict_function()
主要似乎是在加载经过训练的模型后立即添加异步事件处理程序时触发的,这对我有用。例如,
from keras.models import load_model
model=load_model('yolo.h5')
model._make_predict_function()
另一种对其他人有效的方法是使用图形并在上下文中进行推理,例如:
global graph
graph = tf.get_default_graph()
with graph.as_default():
res = model.predict()
有关更多见解,请参阅以下链接:
推荐阅读
- sql - 如何为表 1 中的每一行从表 2 中获取一行(有 blob 列)?
- macos - OS X 自定义登录验证
- java - 在 Heroku 上的 Java JAX-RS API 上启用 CORS 请求
- javascript - 几秒钟后,Jquery animation() 响应非常缓慢
- c# - 如何从 Controller 重启 BackgroundService
- odata - SAPUi5 Odatamodel V4 setHeaders 缺失
- java - 创建通用数组作为静态字段
- apache-kafka - kafka 本地状态存储/更改日志中的保留时间
- c++ - 优化特征表达
- python - 如何将泡菜文件加载到numpy数组