Handling multiple files in a TensorFlow session

Problem description

def func_name(loc, id, mn):
    with detection_graph.as_default():
        with tf.compat.v1.Session(graph=detection_graph) as sess:
            # tf.initialize_all_variables().run()

            while cap.isOpened():
                ret, image_np = cap.read()
                print(ret)

                if not ret:
                    break
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                image_np_expanded = np.expand_dims(image_np, axis=0)
                # Extract image tensor
    sess.close()

Using the ordinary object-detection session code above, I call func_name(location, id, model_name) to send one file for processing; it is processed, saved, and returned. But when I send another file without exiting the program, I get the first frame and then nothing happens, i.e. files after the first one are never processed the way the first file was.

How can I process multiple files without exiting the code and restarting? I have tried initializing variables and sess.close(), but it still does not work. The multiple files are sent in via Flask.
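For context, the files arrive through a Flask endpoint shaped roughly like the sketch below; the route name, form field, and upload directory are assumptions for illustration, not the actual app:

import uuid

from flask import Flask, request

app = Flask(__name__)

@app.route('/upload', methods=['POST'])
def upload():
    f = request.files['video']  # assumed form field name
    location = 'uploads/' + f.filename
    f.save(location)
    # func_name is the detection routine shown above
    func_name(location, str(uuid.uuid4()), 'ssd_inception_v2_coco_2017_11_17')
    return 'File saved successfully'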

UPDATE 1

detect_func() is called from a different script and gets all the parameters it needs from it.

import numpy as np
import os

import six.moves.urllib as urllib
import sys
sys.path.append("..")
import tarfile
import tensorflow as tf
import zipfile
import cv2

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
# from models.research import *
#from models.research.object_detection.utils import label_map_util
from codes.models.research.object_detection.utils import visualization_utils as vis_util
from codes.models.research.object_detection.utils import label_map_util


#cap = cv2.VideoCapture(0)  # Change only if you have more than one webcams

# What model to download.
# Models can be found here: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
MODEL_NAME = 'ssd_inception_v2_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'

# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')

# Number of classes to detect
NUM_CLASSES = 90

# Download Model
if not os.path.exists(os.path.join(os.getcwd(), MODEL_FILE)):
    print("Downloading model")
    opener = urllib.request.URLopener()
    opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    tar_file = tarfile.open(MODEL_FILE)
    for file in tar_file.getmembers():
        file_name = os.path.basename(file.name)
        if 'frozen_inference_graph.pb' in file_name:
            tar_file.extract(file, os.getcwd())


# Load a (frozen) Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')


# Loading label map
# Label maps map indices to category names, so that when our convolution network predicts `5`, we know that this corresponds to `airplane`.  Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


# Helper code
def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)



sess = tf.compat.v1.Session(graph=detection_graph)


def dectect_func(location, id, model_name):
    VID_SAVE_PATH = '/tensorflow/downloads/'
    # Define the video stream
    cap = cv2.VideoCapture(location)  # Open the uploaded video file
    fourcc = cv2.VideoWriter_fourcc('M','J','P','G')
    out = cv2.VideoWriter(VID_SAVE_PATH + id + '.avi',fourcc, 20.0, (640,480))
    while True:
        # Read frame from camera
        ret, image_np = cap.read()
        if not ret:
            break
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Extract image tensor
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Extract detection boxes
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Extract detection scores
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        # Extract detection classes
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        # Extract number of detections
        num_detections = detection_graph.get_tensor_by_name(
            'num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization of the results of a detection.
        '''
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)
        '''
        print(num_detections)

        # Display output
        cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))

        if cv2.waitKey(25) & 0xFF == ord('q'):
            print("pressed q on window")
            cv2.destroyAllWindows()
            break

    cap.release()
    cv2.destroyAllWindows()




# Detection

UPDATE 2

def process_video():
    conn = sqlite3.connect('db/abc.sqlite')
    cur = conn.cursor()
    cur.execute(
        "SELECT id, location, model_name FROM uploads WHERE isProcessed=0 ORDER BY datetime DESC")

    id, location, model_name = cur.fetchone()
    print(id, location, model_name)
    if not (id, location):
        cur.execute(
            "SELECT id, location FROM uploads WHERE isProcessed=0 ORDER BY datetime DESC")
    func_name(location, id, model_name)

    cur.execute("UPDATE uploads SET isProcessed=1 WHERE id='" + id + "'")
    conn.commit()
    conn.close()
    print('yes')
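As an aside, the if not (id, location) guard can never fire, because a non-empty tuple is always truthy in Python. A guard against an empty result set would look more like this sketch:

row = cur.fetchone()
if row is None:
    conn.close()
    return  # nothing left to process
id, location, model_name = row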

UPDATE 3

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
False
yes
File saved successfully
9da51fde-5deb-4f78-8f58-13661723daf8 uploads/output.mp4 ssd_inception_v2_coco_2017_11_17
/tensorflow/ssd_inception_v2_coco_2017_11_17/frozen_inference_graph.pb
True

Here I print True for every frame that is read; the final True is from the second file I passed in, and you can see that file's location and model path above. It only reads the first frame, and then nothing happens.
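A quick sanity check worth adding before the read loop (my own suggestion, not in the original code) is to confirm that the second file actually opens:

cap = cv2.VideoCapture(location)
if not cap.isOpened():
    raise RuntimeError('could not open ' + location)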

Tags: python, tensorflow, flask

Solution


The following modification works for me and lets us reuse the detection loop:


sess = tf.compat.v1.Session(graph=detection_graph)


def dectect_func(cap):
    while True:
        # Read frame from camera
        ret, image_np = cap.read()
        if not ret:
            break
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Extract image tensor
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Extract detection boxes
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Extract detection scores
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        # Extract detection classes
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        # Extract number of detections
        num_detections = detection_graph.get_tensor_by_name(
            'num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization of the results of a detection.
        '''
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)
        '''
        print(num_detections)

        # Display output
        cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))

        if cv2.waitKey(25) & 0xFF == ord('q'):
            print("pressed q on window")
            cv2.destroyAllWindows()
            break


cap = cv2.VideoCapture(0)  # webcam
dectect_func(cap)
dectect_func(cap)

I didn't clone the tf object_detection repo, so I have no visualization here. But I can see num_detections change when I rotate the camera.
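One optional refinement (my addition, not required for the fix): since the tensor names never change, the handles can be resolved once outside the per-frame loop instead of on every iteration:

# Resolve the tensor handles once, up front
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes_t = detection_graph.get_tensor_by_name('detection_boxes:0')
scores_t = detection_graph.get_tensor_by_name('detection_scores:0')
classes_t = detection_graph.get_tensor_by_name('detection_classes:0')
num_t = detection_graph.get_tensor_by_name('num_detections:0')

def run_detection(image_np_expanded):
    # One sess.run per frame against the pre-resolved handles
    return sess.run([boxes_t, scores_t, classes_t, num_t],
                    feed_dict={image_tensor: image_np_expanded})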

EDIT: I think there is a problem with how opencv saves the file. Try this code:

def dectect_func(location, id):
    print('processing: ', location, id)
    VID_SAVE_PATH = 'out'
    # Define the video stream
    cap = cv2.VideoCapture(location)  # Open the input video file
    fourcc = cv2.VideoWriter_fourcc('M','J','P','G')
    out = cv2.VideoWriter(VID_SAVE_PATH + id + '.avi', fourcc, 20.0, (640,480))
    while True:
        # Read frame from camera
        ret, image_np = cap.read()
        if not ret:
            break
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Extract image tensor
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Extract detection boxes
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Extract detection scores
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        # Extract detection classes
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        # Extract number of detections
        num_detections = detection_graph.get_tensor_by_name(
            'num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization of the results of a detection.

        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)

        print(num_detections)

        # Resize to the writer's resolution; otherwise no file is saved when the resolutions mismatch
        frame = cv2.resize(image_np, (640, 480), interpolation=cv2.INTER_CUBIC)

        out.write(frame)



    cap.release()
    out.release()
    cv2.destroyAllWindows()



# Detection
dectect_func('small.mp4','0')
dectect_func('small.mp4','1')
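If the hard-coded 640x480 does not match your source video, you can also size the writer from the capture itself; this is a sketch, assuming the container reports its properties correctly:

cap = cv2.VideoCapture(location)
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) or 20.0  # fall back if FPS is unreported
fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
out = cv2.VideoWriter(VID_SAVE_PATH + id + '.avi', fourcc, fps, (w, h))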

