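python - Human detection using OpenCV, TensorFlow and Python
Problem description
I am working on a robotics project that involves detecting humans, and I am using TensorFlow with a predefined dataset to create the trained model. Since I am new to machine learning, I can't get the output I want from the classifier. I only need person detection and want to avoid detecting balls, laptops, or other objects. Right now my webcam detects every object, such as balls, bats, laptops, and TVs. The only output I need is people with a score of 80% or higher.
The code I use to create the model is: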
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90
if not os.path.exists(MODEL_NAME + '/frozen_inference_graph.pb'):
    print('Downloading the model')
    opener = urllib.request.URLopener()
    opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    tar_file = tarfile.open(MODEL_FILE)
    for file in tar_file.getmembers():
        file_name = os.path.basename(file.name)
        if 'frozen_inference_graph.pb' in file_name:
            tar_file.extract(file, os.getcwd())
    print('Download complete')
else:
    print('Model already exists')
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
import cv2
cap = cv2.VideoCapture(1)
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        ret = True
        while (ret):
            ret, image_np = cap.read()
            if not ret:
                # Stop when no frame could be read from the camera
                break
            image_np_expanded = np.expand_dims(image_np, axis=0)
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
            scores = detection_graph.get_tensor_by_name('detection_scores:0')
            classes = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections = detection_graph.get_tensor_by_name('num_detections:0')
            (boxes, scores, classes, num_detections) = sess.run(
                [boxes, scores, classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            cv2.imshow('image', cv2.resize(image_np, (1280, 960)))
            if cv2.waitKey(27) & 0xFF == ord('q'):
                cv2.destroyAllWindows()
                cap.release()
                break
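Can anyone explain how I can detect only people with an accuracy score greater than 80%?
Solution
As far as I can see from the documentation, you only need to check for the person class. Right now vis_util draws every class. You have to add an if condition that lets only the person class through. The appropriate identifier is given below (taken from the documentation).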
item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
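One way to apply that condition is to filter the detection arrays before they are drawn. The following is only a sketch, assuming the arrays returned by sess.run have the usual shapes for this model (boxes as 1 x N x 4, classes and scores as 1 x N); the helper name filter_person_detections and the constants PERSON_CLASS_ID and MIN_SCORE are made up for this example and are not part of the Object Detection API.

```python
import numpy as np

PERSON_CLASS_ID = 1   # id of "person" in mscoco_label_map.pbtxt (see the item above)
MIN_SCORE = 0.8       # keep only detections scored 80% or higher


def filter_person_detections(boxes, classes, scores,
                             class_id=PERSON_CLASS_ID, min_score=MIN_SCORE):
    """Keep only detections of one class whose score reaches min_score.

    Hypothetical helper: accepts the raw (batch of 1) or already squeezed
    arrays and returns flat arrays of boxes, classes and scores.
    """
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    classes = np.asarray(classes).reshape(-1).astype(np.int32)
    scores = np.asarray(scores, dtype=np.float32).reshape(-1)

    # Boolean mask: right class AND confident enough
    keep = (classes == class_id) & (scores >= min_score)
    return boxes[keep], classes[keep], scores[keep]


# Made-up detections standing in for one frame's sess.run output
sample_boxes = np.array([[[0.1, 0.1, 0.5, 0.5],
                          [0.2, 0.2, 0.6, 0.6],
                          [0.3, 0.3, 0.7, 0.7],
                          [0.0, 0.0, 1.0, 1.0]]])
sample_classes = np.array([[1.0, 1.0, 37.0, 72.0]])
sample_scores = np.array([[0.95, 0.60, 0.99, 0.85]])

person_boxes, person_classes, person_scores = filter_person_detections(
    sample_boxes, sample_classes, sample_scores)
print(person_boxes, person_classes, person_scores)
```

In the capture loop, the three filtered arrays could be passed to vis_util.visualize_boxes_and_labels_on_image_array in place of the squeezed originals. That function also takes a min_score_thresh argument (0.5 by default), so setting min_score_thresh=0.8 there keeps its drawing threshold consistent with the filter.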