Keras and TensorFlow (backend error): Tensor conv2d_1_input:0, specified in either feed_devices or fetch_devices was not found in the Graph

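Problem description

I am using Keras and TensorFlow, and I am completely new to both. I have trained my model, and the error appears when I make a prediction. This is the code I use for image prediction: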

import numpy as np
from flask import Flask, request, jsonify, render_template
import numpy
from PIL import Image
import os
import tensorflow.keras
from werkzeug.utils import secure_filename
from keras.models import load_model

app = Flask(__name__)

model = load_model('traffic_classifier.h5')
model._make_predict_function()

@app.route('/')
def index():
    # Main page
    return render_template('index.html')

@app.route('/traffic')
def traffic():
    # Main page
    return render_template('traffic.html')

@app.route('/sleep')
def sleep():
    # Main page
    return render_template('sleep.html')

@app.route('/predict',methods=['POST'])
def predict():
    '''
    For rendering results on HTML GUI
    '''



    classes = { 1:'Speed limit (20km/h)',
            2:'Speed limit (30km/h)',      
            3:'Speed limit (50km/h)',       
            4:'Speed limit (60km/h)',      
            5:'Speed limit (70km/h)',    
            6:'Speed limit (80km/h)',      
            7:'End of speed limit (80km/h)',     
            8:'Speed limit (100km/h)',    
            9:'Speed limit (120km/h)',     
           10:'No passing',   
           11:'No passing veh over 3.5 tons',     
           12:'Right-of-way at intersection',     
           13:'Priority road',    
           14:'Yield',     
           15:'Stop',       
           16:'No vehicles',       
           17:'Veh > 3.5 tons prohibited',       
           18:'No entry',       
           19:'General caution',     
           20:'Dangerous curve left',      
           21:'Dangerous curve right',   
           22:'Double curve',      
           23:'Bumpy road',     
           24:'Slippery road',       
           25:'Road narrows on the right',  
           26:'Road work',    
           27:'Traffic signals',      
           28:'Pedestrians',     
           29:'Children crossing',     
           30:'Bicycles crossing',       
           31:'Beware of ice/snow',
           32:'Wild animals crossing',      
           33:'End speed + passing limits',      
           34:'Turn right ahead',     
           35:'Turn left ahead',       
           36:'Ahead only',      
           37:'Go straight or right',      
           38:'Go straight or left',      
           39:'Keep right',     
           40:'Keep left',      
           41:'Roundabout mandatory',     
           42:'End of no passing',      
           43:'End no passing veh > 3.5 tons' }




    if request.method == "POST":
        # image = request.form["fileupload"]

        f = request.files['file']

        # Save the file to ./uploads
        basepath = os.path.dirname(__file__)
        file_path = os.path.join(
            basepath, 'uploads', secure_filename(f.filename))
        f.save(file_path)  


    image = Image.open(file_path)
    image = image.resize((30,30))
    image = numpy.expand_dims(image, axis=0)
    image = numpy.array(image)

    pred = model.predict_classes([image])[0]

    sign = classes[pred+1]





    return render_template('traffic.html', prediction_text='This sign represents {}'.format(sign))


if __name__ == "__main__":
    app.run(debug=True)

I get this error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Tensor conv2d_1_input:0, specified in either feed_devices or fetch_devices was not found in the Graph

What should I do?

Tags: python, tensorflow, flask, keras

Solution


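The problem is that Flask uses threads: each request is handled in its own thread. In TensorFlow 1.x the default graph is a property of the current thread, so the graph your model was loaded into at startup is not the default graph inside the request thread. As a result, the model's tensors (such as conv2d_1_input:0) are not visible in the request.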

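To fix this, make the model part of a global session and use that same session everywhere the model is called.

The following solution was suggested for this error: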

import tensorflow as tf
from tensorflow.python.keras.backend import set_session
from tensorflow.python.keras.models import load_model

# TensorFlow 1.x API. Placeholder config: replace with your own session settings.
tf_config = tf.ConfigProto()
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()

# IMPORTANT: models have to be loaded AFTER SETTING THE SESSION for Keras!
# Otherwise, their weights will be unavailable in the request threads once the session has been set there.
set_session(sess)
model = load_model(...)

Then, in your method:

def predict():
    ...
    global sess
    global graph
    with graph.as_default():
        # Re-attach the global session in the request thread before predicting
        set_session(sess)
        pred = model.predict_classes(...)
    ...
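The order of operations in the method above (enter the graph, set the session, then predict) could be wrapped in a small helper so every route uses it the same way. This is a minimal sketch, not part of the original answer: `predict_with_session` is a hypothetical name, and `FakeGraph`, `FakeModel`, and `fake_set_session` are stand-ins that only record calls so the example runs without TensorFlow installed.

```python
from contextlib import contextmanager

def predict_with_session(model, graph, sess, set_session, batch):
    # Make the stored graph the default for this thread, re-attach the
    # session, and only then run the prediction.
    with graph.as_default():
        set_session(sess)
        return model.predict_classes(batch)

# --- Hypothetical stand-ins so the sketch runs without TensorFlow ---
calls = []

class FakeGraph:
    @contextmanager
    def as_default(self):
        calls.append("graph entered")
        yield self
        calls.append("graph exited")

class FakeModel:
    def predict_classes(self, batch):
        calls.append("predict")
        return [0 for _ in batch]

def fake_set_session(sess):
    calls.append("session set")

pred = predict_with_session(
    FakeModel(), FakeGraph(), "fake-session", fake_set_session, ["image"]
)[0]
print(pred, calls)
```

In the real application you would pass the global `model`, `graph`, `sess`, and the imported `set_session` instead of the stand-ins.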
