python - 在 tensorflow 会话中处理多个文件
问题描述
def func_name(loc, id, mn):
    """Read every frame of the video at ``loc`` inside a fresh TF session.

    This is the question's skeleton: the detection step that follows the
    frame expansion is elided in the original snippet. ``id`` and ``mn`` are
    accepted for caller compatibility but are not used here.

    The capture is opened from ``loc`` on every call and always released.
    The original read from a ``cap`` defined outside the function. Once that
    capture had been read to its end, presumably no later call could process
    a new file.
    """
    cap = cv2.VideoCapture(loc)
    try:
        with detection_graph.as_default():
            # The ``with`` statement closes the session on exit; the original's
            # explicit sess.close() inside the block was redundant.
            with tf.compat.v1.Session(graph=detection_graph) as sess:
                while cap.isOpened():
                    ret, image_np = cap.read()
                    print(ret)
                    if not ret:
                        break
                    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                    image_np_expanded = np.expand_dims(image_np, axis=0)
                    # Extract the image tensor and run detection with ``sess``
                    # here (elided in the original snippet).
    finally:
        cap.release()
我这样调用上面的函数:
func_name(location, id, model_name)
上面这段普通的目标检测会话代码会接收一个文件进行处理,保存结果后返回。但是当我在不退出程序的情况下再发送另一个文件时,只读到了第一帧,之后就什么也没有发生了。也就是说,第一个文件之后的所有文件都无法像第一个文件那样被完整处理。
如何在不退出并重启程序的情况下处理多个文件?我尝试过 `initialize variables` 和 `sess.close()`,但仍然不起作用。这些文件是通过 Flask 发送过来的。
UPDATE 1
我从另一个脚本调用 `dectect_func()`,它所需的所有参数也都从那个脚本获取。
import numpy as np
import os
import six.moves.urllib as urllib
import sys
sys.path.append("..")
import tarfile
import tensorflow as tf
import zipfile
import cv2
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
# from models.research import *
#from models.research.object_detection.utils import label_map_util
from codes.models.research.object_detection.utils import visualization_utils as vis_util
from codes.models.research.object_detection.utils import label_map_util
# For a webcam instead of a file: cap = cv2.VideoCapture(0)
# What model to download.
# Models can be found here: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
MODEL_NAME = 'ssd_inception_v2_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
# Label map file: the strings used to label each detected box.
PATH_TO_LABELS = os.path.join('/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')
# Number of classes to detect
NUM_CLASSES = 90

# Download the model archive only if it is not already in the working directory.
if not os.path.exists(os.path.join(os.getcwd(), MODEL_FILE)):
    print("Downloading model")
    # URLopener is deprecated; urlretrieve is the supported equivalent.
    urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)

# Extract the frozen graph whenever it is missing. This also covers an archive
# that was downloaded earlier but never extracted. The ``with`` closes the archive.
if not os.path.exists(PATH_TO_CKPT):
    with tarfile.open(MODEL_FILE) as tar_file:
        for member in tar_file.getmembers():
            if 'frozen_inference_graph.pb' in os.path.basename(member.name):
                tar_file.extract(member, os.getcwd())

# Load the frozen TensorFlow model into its own graph.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# Load the label map, which maps class indices predicted by the network (e.g. 5)
# to category names (e.g. 'airplane').
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
# Helper code
def load_image_into_numpy_array(image):
    """Return the pixels of ``image`` as a (height, width, 3) uint8 array.

    ``image.size`` is read as ``(width, height)`` and ``image.getdata()`` must
    yield exactly three values per pixel (presumably a PIL RGB image); any
    other channel count makes the reshape fail.
    """
    width, height = image.size
    target_shape = (height, width, 3)
    flat_pixels = np.array(image.getdata())
    return flat_pixels.reshape(target_shape).astype(np.uint8)
# One session for the module-level graph, reused by every dectect_func call.
sess = tf.compat.v1.Session(graph=detection_graph)


def dectect_func(location, id, model_name, display=True):
    """Run object detection on every frame of a video and save the frames.

    Each frame read from ``location`` is fed through ``detection_graph`` using
    the shared ``sess``. The detection count is printed, and the frame, resized
    to 640x480, is written to ``/tensorflow/downloads/<id>.avi``.

    Args:
        location: source passed to ``cv2.VideoCapture`` (file path, URL, ...).
        id: identifier used as the output file name.
        model_name: currently unused; kept for caller compatibility.
        display: when True (the original behavior), show each frame in an
            OpenCV window and stop early on 'q'. Pass False when running
            headless, e.g. from a web-server worker.
    """
    VID_SAVE_PATH = '/tensorflow/downloads/'
    frame_size = (640, 480)  # must match the size given to VideoWriter

    # Tensor handles depend only on the graph, so look them up once per call
    # rather than once per frame.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    fetches = [
        detection_graph.get_tensor_by_name('detection_boxes:0'),
        detection_graph.get_tensor_by_name('detection_scores:0'),
        detection_graph.get_tensor_by_name('detection_classes:0'),
        detection_graph.get_tensor_by_name('num_detections:0'),
    ]

    cap = cv2.VideoCapture(location)
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(VID_SAVE_PATH + str(id) + '.avi', fourcc, 20.0, frame_size)
    try:
        while True:
            ret, image_np = cap.read()
            if not ret:
                break
            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
            image_np_expanded = np.expand_dims(image_np, axis=0)
            boxes, scores, classes, num_detections = sess.run(
                fetches, feed_dict={image_tensor: image_np_expanded})
            # Box drawing (vis_util.visualize_boxes_and_labels_on_image_array)
            # is disabled in this version, so boxes/scores/classes go unused.
            print(num_detections)
            # The original opened the writer but never wrote to it, leaving the
            # saved file empty. Frames must match frame_size, otherwise no file
            # content is saved.
            out.write(cv2.resize(image_np, frame_size, interpolation=cv2.INTER_CUBIC))
            if display:
                cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    print("pressed q on window")
                    break
    finally:
        # Release even on errors so the .avi is finalized and the next call
        # starts with fresh capture/writer objects.
        cap.release()
        out.release()
        if display:
            cv2.destroyAllWindows()
# Detection
更新 2:
def process_video(db_path='db/abc.sqlite', process=None):
    """Process the newest unprocessed upload and mark it as processed.

    Selects the most recent row of ``uploads`` with ``isProcessed=0`` (by
    ``datetime`` descending), runs ``process(location, id, model_name)`` on it,
    then sets ``isProcessed=1`` for that id. If no row is pending, the
    function returns without doing anything.

    Args:
        db_path: SQLite database file. Defaults to the original hard-coded path.
        process: callable used for the actual work. Defaults to the
            module-level ``func_name``.
    """
    if process is None:
        process = func_name
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute(
            "SELECT id, location, model_name FROM uploads WHERE isProcessed=0 order by datetime DESC")
        row = cur.fetchone()
        if row is None:
            # Nothing pending; the original crashed unpacking None here.
            return
        upload_id, location, model_name = row
        print(upload_id, location, model_name)
        process(location, upload_id, model_name)
        # Parameterized query: the id must never be spliced into the SQL text.
        cur.execute("UPDATE uploads SET isProcessed=1 WHERE id=?", (upload_id,))
        conn.commit()
    finally:
        conn.close()
    print('yes')
update 3
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
False
yes
File saved successfully
9da51fde-5deb-4f78-8f58-13661723daf8 uploads/output.mp4 ssd_inception_v2_coco_2017_11_17
/tensorflow/ssd_inception_v2_coco_2017_11_17/frozen_inference_graph.pb
True
这里打印的 True 表示是否成功读取到一帧。最后一个 True 来自我传入的第二个文件(上面打印出了该文件的 id、路径和模型名)。它只读取了第一帧,之后就没有任何反应了。
解决方案
下面的修改对我有效,它让检测循环可以被重复使用:
# One session for the graph, created once and reused by every call below.
sess = tf.compat.v1.Session(graph=detection_graph)


def dectect_func(cap):
    """Run detection on frames from an already-open capture.

    Loops until the capture stops returning frames or 'q' is pressed in the
    display window. The capture is not released here, so the same ``cap`` can
    be passed to another call, as done below.
    """
    # Tensor handles depend only on the graph; look them up once per call.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes_tensor = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores_tensor = detection_graph.get_tensor_by_name('detection_scores:0')
    classes_tensor = detection_graph.get_tensor_by_name('detection_classes:0')
    num_tensor = detection_graph.get_tensor_by_name('num_detections:0')
    while True:
        # Read frame from camera
        ret, image_np = cap.read()
        if not ret:
            # End of stream or read failure: image_np is None, which would
            # crash both sess.run and cv2.resize below.
            break
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        boxes, scores, classes, num_detections = sess.run(
            [boxes_tensor, scores_tensor, classes_tensor, num_tensor],
            feed_dict={image_tensor: image_np_expanded})
        # Box drawing via vis_util is omitted here (no object_detection repo).
        print(num_detections)
        # Display output
        cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))
        if cv2.waitKey(25) & 0xFF == ord('q'):
            print("pressed q on window")
            cv2.destroyAllWindows()
            break


dectect_func(cap)
dectect_func(cap)
我没有克隆 TensorFlow 的 object_detection 仓库,所以这里没有做可视化。但当我转动摄像头时,可以看到 `num_detections` 在变化。
编辑:我认为问题出在 OpenCV 保存文件的环节。试试下面的代码:
def dectect_func(location, id):
    """Detect objects in every frame of ``location`` and save an annotated video.

    Boxes are drawn onto each frame with ``vis_util``. The frame is then
    resized to 640x480 and written to ``'out' + id + '.avi'``; ``'out'`` is a
    filename prefix, so the result is e.g. ``out0.avi`` in the working directory.
    """
    print('processing: ', location, id)
    VID_SAVE_PATH = 'out'
    frame_size = (640, 480)  # must match the size given to VideoWriter

    # Tensor handles depend only on the graph; look them up once per call.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    fetches = [
        detection_graph.get_tensor_by_name('detection_boxes:0'),
        detection_graph.get_tensor_by_name('detection_scores:0'),
        detection_graph.get_tensor_by_name('detection_classes:0'),
        detection_graph.get_tensor_by_name('num_detections:0'),
    ]

    cap = cv2.VideoCapture(location)
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(VID_SAVE_PATH + id + '.avi', fourcc, 20.0, frame_size)
    try:
        while True:
            ret, image_np = cap.read()
            if not ret:
                break
            # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
            image_np_expanded = np.expand_dims(image_np, axis=0)
            boxes, scores, classes, num_detections = sess.run(
                fetches, feed_dict={image_tensor: image_np_expanded})
            # Draw the detections onto image_np in place.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            print(num_detections)
            # Otherwise no file is saved if the resolution mismatches. The
            # interpolation flag must be passed by keyword: the third
            # positional argument of cv2.resize is ``dst``.
            frame = cv2.resize(image_np, frame_size, interpolation=cv2.INTER_CUBIC)
            out.write(frame)
    finally:
        # Release even on errors so the .avi is finalized before the next call.
        cap.release()
        out.release()
        cv2.destroyAllWindows()


# Detection
dectect_func('small.mp4', '0')
dectect_func('small.mp4', '1')
推荐阅读
- python - 循环中的“-1”是什么意思
- mongodb - 在 mongo 事务中保存许多实体会导致 WriteConflict
- android - notifyDataSetChanged 在使用套接字时使 viewpager 幻灯片滞后
- selenium - Python selenium 动作链
- python - 统计函数中的空数据框
- python - 图像上的openCV文本 - 背景图像不清晰
- python - 尝试使用 send_config_set 时无法在 Netmiko 中进入配置模式
- python - 排放概率表的最佳数据结构是什么?
- android - 在平面设计中使用国家国旗表情符号
- firebase-realtime-database - Firebase 实时数据库 - 通配符规则