TF.js: loading a model.json object detection model for custom classes in the browser

Problem description

I am using TensorFlow.js for object detection and am trying to run a custom object detection model in the browser. I converted the TensorFlow model (the exported inference graph) to a tensorflow.js model. Prediction in Python works fine with the following code:

import io
import os
import scipy.misc
import numpy as np
import six
import time
import glob
from IPython.display import display

from six import BytesIO

import matplotlib
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import tensorflow as tf
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

%matplotlib inline
def load_image_into_numpy_array(path):

  img_data = tf.io.gfile.GFile(path, 'rb').read()
  image = Image.open(BytesIO(img_data))
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)
category_index = label_map_util.create_category_index_from_labelmap('Models/Annotation/label_map.pbtxt', use_display_name=True)

tf.keras.backend.clear_session()
model = tf.saved_model.load(f'Models/saved_model/')
def run_inference_for_single_image(model, image):
  image = np.asarray(image)

  # Convert the image to a tensor and add a batch dimension:
  input_tensor = tf.convert_to_tensor(image)
  input_tensor = input_tensor[tf.newaxis, ...]

  # Run inference through the saved model's serving signature:
  model_fn = model.signatures['serving_default']
  output_dict = model_fn(input_tensor)


  num_detections = int(output_dict.pop('num_detections'))
  output_dict = {key:value[0, :num_detections].numpy() 
                 for key,value in output_dict.items()}
  output_dict['num_detections'] = num_detections

  
  output_dict['detection_classes'] = output_dict['detection_classes'].astype(np.int64)
   
 
  if 'detection_masks' in output_dict:
    
    detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
              output_dict['detection_masks'], output_dict['detection_boxes'],
               image.shape[0], image.shape[1])      
    detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5,
                                       tf.uint8)
    output_dict['detection_masks_reframed'] = detection_masks_reframed.numpy()
    
  return output_dict

for image_path in glob.glob('images/Group_A_406.jpg'):
  image_np = load_image_into_numpy_array(image_path)
  output_dict = run_inference_for_single_image(model, image_np)
  scores = np.squeeze(output_dict['detection_scores'])
  vis_util.visualize_boxes_and_labels_on_image_array(
      image_np,
      output_dict['detection_boxes'],
      output_dict['detection_classes'],
      output_dict['detection_scores'],
      category_index,
      instance_masks=output_dict.get('detection_masks_reframed', None),
      use_normalized_coordinates=True,
      max_boxes_to_draw=50,
      min_score_thresh=.45,
      line_thickness=8)
  display(Image.fromarray(image_np))


The index.html snippet:

<!DOCTYPE html>
<html lang="en">
<head>
    <title>The Recognizer</title>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <link href="styles.css" rel="stylesheet">
</head>
<body>
<h1>Object Detection</h1>
<section id="demos">
    <div id="liveView" >
    <button id="webcamButton" class="invisible">Loading...</button>
        <video id="webcam" class="background" playsinline crossorigin="anonymous"></video>
    </div>
</section>
<!-- Import TensorFlow.js library -->
<!-- <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script> -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.3.0/dist/tf.min.js"></script>
<script src="script.js"></script>
</body>
</html>

The script.js snippet:

//Store hooks and video sizes:
const video = document.getElementById('webcam');
const liveView = document.getElementById('liveView');
const demosSection = document.getElementById('demos');
const enableWebcamButton = document.getElementById('webcamButton');
const vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0)
const vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0)
var vidWidth = 0;
var vidHeight = 0;
var xStart = 0;
var yStart = 0;




// Check if webcam access is supported.
function getUserMediaSupported() {
    return !!(navigator.mediaDevices &&
        navigator.mediaDevices.getUserMedia);
}

// If webcam supported, add event listener to activation button:
if (getUserMediaSupported()) {
    enableWebcamButton.addEventListener('click', enableCam);
} else {
    console.warn('getUserMedia() is not supported by your browser');
}

// Enable the live webcam view and start classification.
function enableCam(event) {
  // Only continue if model has finished loading.
  if (!model) {
    return;
  }
  // Hide the button once clicked.
  enableWebcamButton.classList.add('removed');

  // getUserMedia constraints: video only (no audio), preferring the rear
  // camera via facingMode "environment" (this also works in Safari).
  const constraints = {
    video: {
      facingMode: "environment"
    }
  };

  // Stream the camera into the <video> element:
  navigator.mediaDevices.getUserMedia(constraints).then(stream => {
    let $video = document.querySelector('video');
    $video.srcObject = stream;
    $video.onloadedmetadata = () => {
      vidWidth = $video.videoHeight;
      vidHeight = $video.videoWidth;
      //The start position of the video (from top left corner of the viewport)
      xStart = Math.floor((vw - vidWidth) / 2);
      yStart = (Math.floor((vh - vidHeight) / 2)>=0) ? (Math.floor((vh - vidHeight) / 2)):0;
      $video.play();
      //Attach detection model to loaded data event:
      $video.addEventListener('loadeddata', predictWebcamTF);
    }
  });
}



var model = undefined;
const model_url = 'https://raw.githubusercontent.com/.../model/mobile_netv2/web_model2/model.json';
//Call the load function:
asyncLoadModel(model_url);

//Loads the converted graph model (model.json) and enables the start button:
async function asyncLoadModel(model_url) {
    model = await tf.loadGraphModel(model_url);
    console.log('Model loaded');
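    //(Sketch) Log the converted model's input/output signature; tf.GraphModel
    //exposes these as TensorInfo lists, which is one way to trace the tensor
    //shapes and the output ordering the rest of this script has to match:
    console.log('inputs:', model.inputs.map(t => ({ name: t.name, shape: t.shape, dtype: t.dtype })));
    console.log('outputs:', model.outputs.map(t => ({ name: t.name, shape: t.shape, dtype: t.dtype })));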
    //Enable start button:
    enableWebcamButton.classList.remove('invisible');
    enableWebcamButton.innerHTML = 'Start camera';
}



var children = [];
//Perform prediction on the webcam feed using the loaded graph model:
function predictWebcamTF() {
    // Now let's start classifying a frame in the stream.
    detectTFMOBILE(video).then(function () {
        // Call this function again to keep predicting when the browser is ready.
        window.requestAnimationFrame(predictWebcamTF);
    });
}
const imageSize = 300; //Note: not used in this snippet; the model was trained at 320x320.
//Minimum class probability threshold for object detection:
var classProbThreshold = 0.4; //40%
//Run detection on a single video frame and draw the results:
async function detectTFMOBILE(imgToPredict) {

    //Get next video frame:
    await tf.nextFrame();
    //Create tensor from image:
    const tfImg = tf.browser.fromPixels(imgToPredict);

    //Resize the frame to the size used for detection:
    const smallImg = tf.image.resizeBilinear(tfImg, [vidHeight, vidWidth]);

    //Cast to int32 and add a batch dimension -> shape [1, vidHeight, vidWidth, 3]:
    const resized = tf.cast(smallImg, 'int32');
    var tf4d_ = tf.tensor4d(Array.from(resized.dataSync()), [1, vidHeight, vidWidth, 3]);
    const tf4d = tf.cast(tf4d_, 'int32');

    //Perform the detection with the graph model:
    let predictions = await model.executeAsync(tf4d);
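    //(Sketch) Uncomment once to work out which output index holds the boxes
    //([1, N, 4]), the classes ([1, N]) and the scores ([1, N]); the hard-coded
    //indices 4/1/2 used below came from one particular converted model and may differ:
    //predictions.forEach((t, i) => console.log(i, t.shape));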
    
    //Draw box around the detected object:
    renderPredictionBoxes(predictions[4].dataSync(), predictions[1].dataSync(), predictions[2].dataSync());
    //Dispose of the intermediate tensors so they don't leak memory:
    tfImg.dispose();
    smallImg.dispose();
    resized.dispose();
    tf4d_.dispose();
    tf4d.dispose();
    
}



//Function Renders boxes around the detections:
function renderPredictionBoxes (predictionBoxes, predictionClasses, predictionScores)
{
    //Remove all detections:
    for (let i = 0; i < children.length; i++) {
        liveView.removeChild(children[i]);
    }
    children.splice(0);
    //Loop through the predictions and draw the ones with a high confidence score:
    for (let i = 0; i < 99; i++) {
        //Convert the normalized box coordinates into pixel positions in the live view:
        const minY = (predictionBoxes[i * 4] * vidHeight+yStart).toFixed(0);
        const minX = (predictionBoxes[i * 4 + 1] * vidWidth+xStart).toFixed(0);
        const maxY = (predictionBoxes[i * 4 + 2] * vidHeight+yStart).toFixed(0);
        const maxX = (predictionBoxes[i * 4 + 3] * vidWidth+xStart).toFixed(0);
        const score = predictionScores[i * 3] * 100; //Note: the stride-3 indexing is tied to one particular output layout; for a [1, N] scores tensor it would be predictionScores[i].
        const width_ = (maxX - minX).toFixed(0);
        const height_ = (maxY - minY).toFixed(0);
//If confidence is above 70%
        if (score > 70 && score < 100){
            const highlighter = document.createElement('div');
            highlighter.setAttribute('class', 'highlighter');
            highlighter.style = 'left: ' + minX + 'px; ' +
                'top: ' + minY + 'px; ' +
                'width: ' + width_ + 'px; ' +
                'height: ' + height_ + 'px;';
            highlighter.innerHTML = '<p>'+Math.round(score) + '% ' + 'Your Object Name'+'</p>';
            liveView.appendChild(highlighter);
            children.push(highlighter);
        }
    }
}


I am struggling to rewrite the .js code for my custom trained classes. I also cannot figure out the tensor shapes I need to specify in the .js file. I fine-tuned SSD MobileNet V2 320x320 on 4 custom classes; a rough sketch of what I am trying to get to is below.
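
For context, this is roughly what I am aiming for on the .js side: resize each frame to the 320x320 training resolution, feed the model an int32 tensor of shape [1, 320, 320, 3], and map the predicted class ids to my four class names. The sketch is unverified; the classNames values are placeholders for my own label map, CLASSES_INDEX is a placeholder index, and the output ordering of executeAsync still has to be checked against the converted model:

//Rough sketch (unverified): 320x320 int32 input and custom class labels.
//"classNames" is a placeholder for my own label_map entries.
const classNames = { 1: 'class_one', 2: 'class_two', 3: 'class_three', 4: 'class_four' };

async function detectFrame(videoEl, detector) {
  //Build a [1, 320, 320, 3] int32 tensor from the current video frame:
  const input = tf.tidy(() => {
    const frame = tf.browser.fromPixels(videoEl);                 // [height, width, 3]
    const resized = tf.image.resizeBilinear(frame, [320, 320]);   // [320, 320, 3]
    return resized.toInt().expandDims(0);                         // [1, 320, 320, 3]
  });

  //For a converted multi-output detection model this returns an array of tensors:
  const outputs = await detector.executeAsync(input);

  //Log the shapes once to identify boxes ([1, N, 4]) vs classes/scores ([1, N]):
  outputs.forEach((t, i) => console.log('output', i, t.shape));

  //Once the right index is known, the class ids map to the custom labels, e.g.:
  //  const classIds = outputs[CLASSES_INDEX].dataSync();   // CLASSES_INDEX to be determined
  //  console.log(classNames[classIds[0]]);

  input.dispose();
  outputs.forEach(t => t.dispose());
}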

Thanks in advance.

Tags: javascript, object-detection, tensorflow.js, transfer-learning

Solution

