javascript - TF.js 在浏览器中为自定义类加载 model.json 格式的对象检测模型
问题描述
我正在使用 Tensorflow.js 进行对象检测。我正在尝试在浏览器中运行自定义对象检测 tensorflow.js 模型。我将 tensorflow 模型 - 推理图转换为 tensorflow.js 模型。python的预测工作正常,代码如下:
import io
import os
import scipy.misc
import numpy as np
import six
import time
import glob
from IPython.display import display
from six import BytesIO
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import tensorflow as tf
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
%matplotlib inline
def load_image_into_numpy_array(path):
    """Load an image file into a (height, width, 3) uint8 numpy array.

    Args:
        path: Image path readable by ``tf.io.gfile`` (local or remote).

    Returns:
        ``np.ndarray`` of shape (height, width, 3), dtype ``uint8``.
    """
    # Use a context manager so the file handle is always closed
    # (the original leaked the GFile object).
    with tf.io.gfile.GFile(path, 'rb') as f:
        img_data = f.read()
    # Force 3-channel RGB so grayscale or RGBA inputs don't break the
    # fixed (h, w, 3) shape the detector expects.
    image = Image.open(BytesIO(img_data)).convert('RGB')
    return np.asarray(image, dtype=np.uint8)
# Map class ids -> {'id': ..., 'name': ...} dicts from the label map,
# using the human-readable display names for visualization.
category_index = label_map_util.create_category_index_from_labelmap('Models/Annotation/label_map.pbtxt', use_display_name=True)
# Clear any global Keras state left over from earlier runs in this session.
tf.keras.backend.clear_session()
# Load the exported detection SavedModel; the result is callable on a
# batched uint8 image tensor.
model = tf.saved_model.load(f'Models/saved_model/')
def run_inference_for_single_image(model, image):
    """Run the detection model on one image and post-process its outputs.

    Args:
        model: Loaded SavedModel callable; takes a (1, h, w, 3) uint8 tensor
            and returns the TF Object Detection API output dict.
        image: numpy array of shape (height, width, 3), dtype uint8.

    Returns:
        Dict with per-detection numpy arrays ('detection_boxes',
        'detection_classes', 'detection_scores', ...), an int
        'num_detections', and, when the model emits masks,
        'detection_masks_reframed' resized to the full image.
    """
    image = np.asarray(image)
    # Model expects a batched tensor: (1, height, width, 3).
    input_tensor = tf.convert_to_tensor(image)[tf.newaxis, ...]
    # BUG FIX: the original called the undefined name `model_fn`;
    # the callable passed in is the `model` parameter.
    output_dict = model(input_tensor)
    num_detections = int(output_dict.pop('num_detections'))
    # Strip the batch dimension and truncate to the valid detections.
    output_dict = {key: value[0, :num_detections].numpy()
                   for key, value in output_dict.items()}
    output_dict['num_detections'] = num_detections
    # Class ids must be ints for category_index lookups downstream.
    output_dict['detection_classes'] = output_dict['detection_classes'].astype(np.int64)
    if 'detection_masks' in output_dict:
        # Re-project box-relative masks onto full-image coordinates.
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            output_dict['detection_masks'], output_dict['detection_boxes'],
            image.shape[0], image.shape[1])
        # Binarize at 0.5 probability.
        detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5,
                                           tf.uint8)
        output_dict['detection_masks_reframed'] = detection_masks_reframed.numpy()
    return output_dict
# Detect objects in each matching image and show the annotated result inline.
for image_path in glob.glob('images/Group_A_406.jpg'):
    image_np = load_image_into_numpy_array(image_path)
    output_dict = run_inference_for_single_image(model, image_np)
    scores = np.squeeze(output_dict['detection_scores'])
    # Draw boxes, labels and (if present) instance masks onto the array
    # in place; coordinates are normalized to [0, 1].
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        output_dict['detection_boxes'],
        output_dict['detection_classes'],
        output_dict['detection_scores'],
        category_index,
        instance_masks=output_dict.get('detection_masks_reframed', None),
        use_normalized_coordinates=True,
        max_boxes_to_draw=50,
        min_score_thresh=.45,
        line_thickness=8)
    display(Image.fromarray(image_np))
分享 index.html 的代码片段
<!DOCTYPE html>
<html lang="en">
<head>
  <title>The Recognizer</title>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <link href="styles.css" rel="stylesheet">
</head>
<body>
  <h1>Object Detection</h1>
  <section id="demos">
    <div id="liveView">
      <!-- Button stays invisible until the model has loaded (see script.js) -->
      <button id="webcamButton" class="invisible">Loading...</button>
      <video id="webcam" class="background" playsinline crossorigin="anonymous"></video>
    </div>
  </section>
  <!-- TensorFlow.js runtime, pinned to 3.3.0, followed by the app script -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.3.0/dist/tf.min.js"></script>
  <script src="script.js"></script>
</body>
</html>
script.js 的代码片段
// DOM hooks and video geometry shared by the detection loop below:
const video = document.getElementById('webcam');
const liveView = document.getElementById('liveView');
const demosSection = document.getElementById('demos');
const enableWebcamButton = document.getElementById('webcamButton');
// Viewport size; take the max of the two measures for cross-browser safety.
const vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0)
const vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0)
// Displayed video size and its top-left position in the viewport;
// filled in once the stream's metadata is known (see enableCam).
var vidWidth = 0;
var vidHeight = 0;
var xStart = 0;
var yStart = 0;
// Check if webcam access is supported.
function getUserMediaSupported() {
return !!(navigator.mediaDevices &&
navigator.mediaDevices.getUserMedia);
}
// If webcam access is supported, arm the activation button; otherwise warn.
if (getUserMediaSupported()) {
enableWebcamButton.addEventListener('click', enableCam);
} else {
console.warn('getUserMedia() is not supported by your browser');
}
// Enable the live webcam view and start detection.
// `event` is the button's click event (unused, kept for the listener API).
function enableCam(event) {
  // Only continue once the model has finished loading.
  if (!model) {
    return;
  }
  // Hide the activation button once clicked.
  enableWebcamButton.classList.add('removed');
  // Request video (rear "environment" camera) and no audio; this inline
  // constraints object replaces a dead `constraints` variable the original
  // declared but never passed to getUserMedia. Works on Safari too.
  navigator.mediaDevices.getUserMedia({
    video: {
      facingMode: "environment"
    },
  }).then(stream => {
    let $video = document.querySelector('video');
    $video.srcObject = stream;
    $video.onloadedmetadata = () => {
      // NOTE(review): width <- videoHeight and height <- videoWidth look
      // swapped; possibly intentional for a portrait layout — confirm.
      vidWidth = $video.videoHeight;
      vidHeight = $video.videoWidth;
      // Top-left corner of the centred video within the viewport
      // (y is clamped so the video never starts above the viewport).
      xStart = Math.floor((vw - vidWidth) / 2);
      yStart = Math.max(Math.floor((vh - vidHeight) / 2), 0);
      $video.play();
      // Run the detection loop once frame data is available.
      $video.addEventListener('loadeddata', predictWebcamTF);
    }
  });
}
// The detection GraphModel; stays undefined until asyncLoadModel resolves.
var model = undefined;
// BUG FIX: `model_url` was an undeclared implicit global (throws under
// "use strict" / in modules); declare it explicitly.
const model_url = 'https://raw.githubusercontent.com/.../model/mobile_netv2/web_model2/model.json';
// Kick off the asynchronous model load (function declarations are hoisted).
asyncLoadModel(model_url);
// Loads the converted TF.js GraphModel, then reveals the start button.
async function asyncLoadModel(model_url) {
  model = await tf.loadGraphModel(model_url);
  console.log('Model loaded');
  // Enable the start button now that predictions are possible.
  enableWebcamButton.classList.remove('invisible');
  enableWebcamButton.innerHTML = 'Start camera';
}
// Highlight-box DOM nodes currently drawn over the live view.
var children = [];
// Detection loop: classify the current frame, then schedule the next run.
function predictWebcamTF() {
  detectTFMOBILE(video).then(() => {
    // Re-enter the loop when the browser is ready to paint the next frame.
    window.requestAnimationFrame(predictWebcamTF);
  });
}
// Model input size (SSD MobileNet v2 was trained on 300x300).
// NOTE(review): unused in this snippet — detectTFMOBILE resizes to the
// video dimensions instead; confirm which is intended.
const imageSize = 300;
// Minimum class probability for a detection to count (40%).
// NOTE(review): renderPredictionBoxes hard-codes a 70% cut-off and never
// reads this variable — confirm which threshold should win.
var classProbThreshold = 0.4;//40%
// Grab the next video frame, run the detector on it, and draw the results.
async function detectTFMOBILE(imgToPredict) {
  // Wait for the next frame so we read fresh pixels.
  await tf.nextFrame();
  // Video frame -> rank-3 tensor (height, width, 3).
  const tfImg = tf.browser.fromPixels(imgToPredict);
  // Resize to the displayed video size used later for box rendering.
  const smallImg = tf.image.resizeBilinear(tfImg, [vidHeight, vidWidth]);
  const resized = tf.cast(smallImg, 'int32');
  // Add the batch dimension: (1, height, width, 3).
  // BUG FIX: the original copied the data out via Array.from(dataSync())
  // into a new tensor (tf4d_) that was never disposed — leaking memory on
  // every frame. expandDims produces the same int32 batched tensor.
  const tf4d = tf.expandDims(resized, 0);
  // GraphModels with control-flow ops require executeAsync.
  let predictions = await model.executeAsync(tf4d);
  // For this converted model: [4]=boxes, [1]=classes, [2]=scores.
  renderPredictionBoxes(predictions[4].dataSync(), predictions[1].dataSync(), predictions[2].dataSync());
  // Dispose every tensor so repeated frames don't exhaust memory.
  tfImg.dispose();
  smallImg.dispose();
  resized.dispose();
  tf4d.dispose();
  // BUG FIX: the model's output tensors were also never released.
  if (Array.isArray(predictions)) {
    predictions.forEach(t => t.dispose());
  }
}
// Draw a labelled highlight box for each sufficiently confident detection.
// predictionBoxes:   flat [ymin, xmin, ymax, xmax, ...] normalized coords,
//                    4 values per detection.
// predictionClasses: one class id per detection (unused here).
// predictionScores:  one confidence in [0, 1] per detection.
function renderPredictionBoxes (predictionBoxes, predictionClasses, predictionScores)
{
  // Remove the boxes drawn for the previous frame.
  for (let i = 0; i < children.length; i++) {
    liveView.removeChild(children[i]);
  }
  children.splice(0);
  // Walk up to 99 candidate detections.
  for (let i = 0; i < 99; i++) {
    // Map normalized box corners to viewport pixel positions.
    const minY = (predictionBoxes[i * 4] * vidHeight+yStart).toFixed(0);
    const minX = (predictionBoxes[i * 4 + 1] * vidWidth+xStart).toFixed(0);
    const maxY = (predictionBoxes[i * 4 + 2] * vidHeight+yStart).toFixed(0);
    const maxX = (predictionBoxes[i * 4 + 3] * vidWidth+xStart).toFixed(0);
    // BUG FIX: scores are one value per detection, so index with i — the
    // original's i*3 read the wrong detection's score and ran past the
    // end of the array after the first third of the detections.
    const score = predictionScores[i] * 100;
    const width_ = (maxX-minX).toFixed(0);
    const height_ = (maxY-minY).toFixed(0);
    // Keep detections with confidence above 70%.
    // NOTE(review): this ignores classProbThreshold (40%) — confirm.
    if (score > 70 && score < 100){
      const highlighter = document.createElement('div');
      highlighter.setAttribute('class', 'highlighter');
      highlighter.style = 'left: ' + minX + 'px; ' +
        'top: ' + minY + 'px; ' +
        'width: ' + width_ + 'px; ' +
        'height: ' + height_ + 'px;';
      highlighter.innerHTML = '<p>'+Math.round(score) + '% ' + 'Your Object Name'+'</p>';
      liveView.appendChild(highlighter);
      children.push(highlighter);
    }
  }
}
我正在努力为自定义训练类重写 .js 代码。此外,我无法追踪需要在 .js 文件中提及的张量形状。我在 4 个自定义类上使用 ssd mobilenetv2 320*320 进行了微调。
提前致谢
解决方案
推荐阅读
- android - 如何实现下图中提供的类似 facebook 的弹出窗口?尝试使用警报对话框
- javascript - 在引导程序 datetimepicker 3 上显示 24 小时制时钟
- wpf - 过滤 observablecollection (mvvm)
- c++ - 如何在给定四个点的情况下校准相机焦距、平移和旋转?
- url - 如何将编码的字符串传递给php中的url?
- dart - CircularProgressIndicator 不显示在颤振中
- php - 未在 eBay 重定向 URL 上获得响应代码
- reactjs - 重置redux-form时react-select不清除值
- android - sqlite数据库中的图像索引
- php - Laravel 方法 Illuminate\Mail\Mailer::test 不存在