How to display object detection bounding boxes with TensorFlow and React Native

Problem description

I am building a React Native app that uses a custom trained model to recognize objects.

I have everything set up: the camera is working, and we have exported the TensorFlow model (json + bin files) and hosted them on our web server so they can be loaded into the app via "tf.loadGraphModel" (so far so good).

Calling predict and doing a few conversions gives me an array of what the model thinks is most likely. Is it possible to display bounding boxes for those predictions with TensorFlow in React Native while the expo-camera feed is attached (in real time, while the video feed is open), and if so, how?

Also, is there an easier way to get the predicted classes, or is this the way to go?
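
The kind of overlay I have in mind is something along the lines of the sketch below (untested, just to illustrate the question): each decoded detection rendered as an absolutely positioned View on top of the camera preview. The names detections, previewSize, classId, score and box are placeholders, and the boxes are assumed to already be normalized [ymin, xmin, ymax, xmax] values that get scaled up to the on-screen preview size.

import React from 'react';
import { View, Text, StyleSheet } from 'react-native';

// Hypothetical overlay component: draws one absolutely positioned View
// per decoded detection on top of the camera preview.
// `detections` and `previewSize` are placeholder props, not a real API.
function DetectionOverlay({ detections, previewSize }) {
  return (
    <View style={StyleSheet.absoluteFill} pointerEvents="none">
      {detections.map((d, i) => {
        // boxes assumed to be normalized [ymin, xmin, ymax, xmax]
        const [ymin, xmin, ymax, xmax] = d.box;
        return (
          <View
            key={i}
            style={{
              position: 'absolute',
              left: xmin * previewSize.width,
              top: ymin * previewSize.height,
              width: (xmax - xmin) * previewSize.width,
              height: (ymax - ymin) * previewSize.height,
              borderWidth: 2,
              borderColor: 'red',
            }}
          >
            <Text style={{ color: 'red' }}>
              {`${d.classId} (${d.score.toFixed(2)})`}
            </Text>
          </View>
        );
      })}
    </View>
  );
}

The idea would be to keep the latest detections in component state, update that state from the prediction loop, and render the overlay as a sibling of the camera component inside the same container.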

My screen currently looks like this:

import React, { useState, useCallback } from 'react';
import { Platform, View, Text } from 'react-native';
import { useNavigation, useFocusEffect } from '@react-navigation/native';
import * as tf from '@tensorflow/tfjs';
import { cameraWithTensors, asyncStorageIO } from '@tensorflow/tfjs-react-native';
import { Camera } from 'expo-camera';

// `styles` (container, camera) is assumed to be defined elsewhere in this file
const TensorCamera = cameraWithTensors(Camera);

function DetectionScreen() {
  const navigation = useNavigation();

  const [tfReady, setTfReady] = useState(false);
  const [focus, setFocus] = useState(false);
  const [modelReady, setModelReady] = useState(false);
  const [hasError, setHasError] = useState(false);
  const [detectionModel, setDetectionModel] = useState(null);

  let textureDims;
  if (Platform.OS === 'ios') {
    textureDims = {
      height: 1920,
      width: 1080,
    };
  } else {
    textureDims = {
      height: 1200,
      width: 1600,
    };
  }

  useFocusEffect(
    useCallback(() => {
      const loadTensor = async () => {
        await tf.ready();
        setTfReady(true);

        let model = null;
        let hasModel = false;

        try {
          model = await tf.loadGraphModel(
            asyncStorageIO('react-native-tensor-flow-model')
          );
          hasModel = true;
          setDetectionModel(model);
          setModelReady(true);

          console.log('Model loaded from storage');
        } catch (e) {
          console.log('Error loading model from storage');
          console.log(e);
        }

        if (!hasModel) {
          try {
            model = await tf.loadGraphModel(
              'https://model-web-server-domain.com/model.json'
            );

            // Save the model to async storage
            await model.save(asyncStorageIO('react-native-tensor-flow-model'));
            setDetectionModel(model);
            setModelReady(true);
          } catch (e) {
            console.log(e);
            setHasError(true);
          }
        }
      };

      setFocus(true);
      loadTensor();

      return () => {
        setFocus(false);
      };
    }, [])
  );

  const handleCameraStream = (images, updatePreview, gl) => {
    const loop = async () => {
      const nextImageTensor = images.next().value;

      if (detectionModel && nextImageTensor) {
        try {
          // needs to be expanded to match the models dims or an error occurs
          const tensor4d = nextImageTensor.expandDims(0);
          // needs a cast or an error occurs
          const float32Tensor = tensor4d.cast('float32');
          // const prediction = await detectionModel.predict(nextImageTensor);
          const prediction = await detectionModel.executeAsync(float32Tensor);

          if (prediction && prediction.length > 0) {
            // argMax over the last axis gives the most likely class per box;
            // tensor.print() returns undefined, so keep the tensor and print it
            const classes = prediction[0].argMax(-1);
            console.log('=== PREDICTION ===');
            classes.print();
          }
        } catch (e) {
          console.log('ERROR PREDICTING FROM MODEL');
          console.log(e);
        }
      }

      requestAnimationFrame(loop);
    };
    loop();
  };

  return (
      <View style={styles.container}>
        {tfReady && focus && modelReady && (
          <>
            <TensorCamera
              // Standard Camera props
              style={styles.camera}
              type={Camera.Constants.Type.back}
              // Tensor related props
              cameraTextureHeight={textureDims.height}
              cameraTextureWidth={textureDims.width}
              resizeHeight={200}
              resizeWidth={150}
              resizeDepth={3}
              onReady={handleCameraStream}
              autorender={true}
            />
          </>
        )}
        {!tfReady && <Text>Loading ...</Text>}
      </View>
  );
}
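
One caveat about the loop in handleCameraStream above: the tensors produced each frame (images.next().value, the expandDims/cast intermediates, and everything returned by executeAsync) are not garbage collected by JavaScript, so without explicit disposal memory grows quickly. A sketch of the same loop with disposal added (same variable names as above, otherwise unchanged):

const loop = async () => {
  const nextImageTensor = images.next().value;

  if (detectionModel && nextImageTensor) {
    const tensor4d = nextImageTensor.expandDims(0);
    const float32Tensor = tensor4d.cast('float32');
    try {
      const prediction = await detectionModel.executeAsync(float32Tensor);
      // ...use the prediction here...
      // then free every output tensor returned by the model
      tf.dispose(prediction);
    } catch (e) {
      console.log('ERROR PREDICTING FROM MODEL', e);
    } finally {
      // free the per-frame input tensors
      tf.dispose([nextImageTensor, tensor4d, float32Tensor]);
    }
  } else if (nextImageTensor) {
    nextImageTensor.dispose();
  }

  requestAnimationFrame(loop);
};

tf.tidy cannot wrap an async function, so in this loop the tensors have to be disposed manually.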

*** UPDATE ***

Some more information about the model and the prediction output

model.json

{
  "format": "graph-model",
  "generatedBy": "2.4.0",
  "convertedBy": "TensorFlow.js Converter v1.7.0",
  "userDefinedMetadata": {
    "signature": {
      "inputs": {
        "ToFloat:0": {
          "name": "ToFloat:0",
          "dtype": "DT_FLOAT",
          "tensorShape": {
            "dim": [
              { "size": "-1" },
              { "size": "-1" },
              { "size": "-1" },
              { "size": "3" }
            ]
          }
        }
      },
      "outputs": {
        "Postprocessor/convert_scores:0": {
          "name": "Postprocessor/convert_scores:0",
          "dtype": "DT_FLOAT",
          "tensorShape": {
            "dim": [{ "size": "-1" }, { "size": "-1" }, { "size": "11" }]
          }
        },
        "Postprocessor/Decode/transpose_1:0": {
          "name": "Postprocessor/Decode/transpose_1:0",
          "dtype": "DT_FLOAT",
          "tensorShape": { "dim": [{ "size": "-1" }, { "size": "4" }] }
        }
      }
    }
  },
  "modelTopology": {/*...*/},
  "weightsManifest": [/*...*/],
}
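
Given that signature (scores shaped [-1, -1, 11] and boxes shaped [-1, 4]), a rough sketch of decoding the two outputs into per-box results might look like the following. Passing the output node names to executeAsync avoids having to guess the order of the returned array; the assumptions that the boxes are normalized [ymin, xmin, ymax, xmax] (typical for SSD-style postprocessors) and that class index 0 may be a background class would both need to be checked against the original model.

import * as tf from '@tensorflow/tfjs';

// Sketch: decode scores + boxes into detections above a score threshold.
// `model` is the loaded graph model, `input` is the [1, h, w, 3] float32
// tensor prepared in handleCameraStream. Everything else is an assumption.
async function decodeDetections(model, input, scoreThreshold = 0.5) {
  const [scoresT, boxesT] = await model.executeAsync(input, [
    'Postprocessor/convert_scores:0',
    'Postprocessor/Decode/transpose_1:0',
  ]);

  const squeezed = scoresT.squeeze();             // [numBoxes, 11] class scores
  const scores = await squeezed.array();
  const boxes = await boxesT.array();             // [numBoxes][4] box coords
  tf.dispose([scoresT, squeezed, boxesT]);

  const detections = [];
  for (let i = 0; i < scores.length; i++) {
    // pick the highest-scoring class for this box (index 0 may be background)
    let classId = 0;
    let best = 0;
    for (let c = 0; c < scores[i].length; c++) {
      if (scores[i][c] > best) {
        best = scores[i][c];
        classId = c;
      }
    }
    if (best >= scoreThreshold) {
      detections.push({ classId, score: best, box: boxes[i] });
    }
  }
  return detections;
}

With something like this in place, the second part of the question (getting the predicted classes) reduces to the argMax-style scan above, and the returned boxes could be handed to an overlay component like the one sketched earlier.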

Prediction output (data tensor)

{"dataId": {"id": 22594}, "dtype": "int32", "id": 17510, "isDisposedInternal": false, "kept": false, "rankType": "3", "shape": [200, 150, 3], "size": 90000, "strides": [450, 3]}

Tensor
    [[[0.0002148, 0.0004637, 0.0005074, 0.0002892, 0.0006514, 0.0002825, 0.000659 , 0.0004711, 0.0006962, 0.0002513, 0.0007014],
      [0.0002115, 0.000315 , 0.0005155, 0.0002003, 0.0006719, 0.0003006, 0.000607 , 0.00035  , 0.000555 , 0.0003226, 0.0011692],
      [0.0002034, 0.0005054, 0.0007887, 0.0003393, 0.0008593, 0.0003684, 0.0009112, 0.0006189, 0.0007553, 0.0006771, 0.0009623],
      ...,
      [0.0031853, 0.0024052, 0.0078735, 0.0032234, 0.0032864, 0.0030518, 0.007637 , 0.0053635, 0.0085449, 0.0039902, 0.0059357],
      [0.0031471, 0.0018387, 0.0050392, 0.0019646, 0.0024433, 0.0026016, 0.0039139, 0.0029011, 0.0051994, 0.0027256, 0.0041809],
      [0.0032482, 0.0017414, 0.0041161, 0.0016489, 0.0021324, 0.001853 , 0.0030632, 0.0022793, 0.0032864, 0.0045204, 0.007637 ]]]

Tags: react-native, tensorflow, tensorflow.js

Solution

