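react-native - How to display object detection bounding boxes with TensorFlow and React Native
Problem description
I am building a React Native application that uses a custom-trained model to recognize objects.
Everything is set up and the camera is working. We exported the TensorFlow model (json + bin files) and host them on our web server, so the app can load them via "tf.loadGraphModel" (so far so good).
Calling predict and applying a few conversions gives me an array of what the model considers most accurate. Is it possible to display the predicted bounding boxes with TensorFlow in React Native while the expo camera is attached (in real time, while the video feed is running), and if so, how?
Also, is there an easier way to get the predicted class, or is this the way to go?
My screen currently looks like this: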
import React, { useCallback, useState } from 'react';
import { Platform, Text, View } from 'react-native';
import { useFocusEffect, useNavigation } from '@react-navigation/native';
import { Camera } from 'expo-camera';
import * as tf from '@tensorflow/tfjs';
import { asyncStorageIO, cameraWithTensors } from '@tensorflow/tfjs-react-native';

// Wrap the expo camera so it yields image tensors
const TensorCamera = cameraWithTensors(Camera);

// Note: `styles` is not shown in the question
function DetectionScreen() {
  const navigation = useNavigation();
  const [tfReady, setTfReady] = useState(false);
  const [focus, setFocus] = useState(false);
  const [modelReady, setModelReady] = useState(false);
  const [hasError, setHasError] = useState(false);
  const [detectionModel, setDetectionModal] = useState(null);

  let textureDims;
  if (Platform.OS === 'ios') {
    textureDims = {
      height: 1920,
      width: 1080,
    };
  } else {
    textureDims = {
      height: 1200,
      width: 1600,
    };
  }

  useFocusEffect(
    useCallback(() => {
      const loadTensor = async () => {
        await tf.ready();
        setTfReady(true);

        let model = null;
        let hasModel = false;

        // Try the locally cached copy first
        try {
          model = await tf.loadGraphModel(
            asyncStorageIO('react-native-tensor-flow-model')
          );
          hasModel = true;
          setDetectionModal(model);
          setModelReady(true);
          console.log('Model loaded from storage');
        } catch (e) {
          console.log('Error loading model from storage');
          console.log(e);
        }

        // Fall back to downloading the model from the web server
        if (!hasModel) {
          try {
            model = await tf.loadGraphModel(
              'https://model-web-server-domain.com/model.json'
            );
            // Save the model to async storage
            await model.save(asyncStorageIO('react-native-tensor-flow-model'));
            setDetectionModal(model);
            setModelReady(true);
          } catch (e) {
            console.log(e);
            setHasError(true);
          }
        }
      };

      setFocus(true);
      loadTensor();

      return () => {
        setFocus(false);
      };
    }, [])
  );

  const handleCameraStream = (images, updatePreview, gl) => {
    const loop = async () => {
      const nextImageTensor = images.next().value;
      if (detectionModel && nextImageTensor) {
        try {
          // needs to be expanded to match the models dims or an error occurs
          const tensor4d = nextImageTensor.expandDims(0);
          // needs a cast or an error occurs
          const float32Tensor = tensor4d.cast('float32');
          // const prediction = await detectionModel.predict(nextImageTensor);
          const prediction = await detectionModel.executeAsync(float32Tensor);
          if (prediction && prediction.length > 0) {
            // print() returns nothing, so keep the tensor and print it separately
            const classes = prediction[0].argMax(-1);
            console.log('=== PREDICTION ===');
            classes.print();
            classes.dispose();
          }
          // free the per-frame tensors so memory does not grow with every frame
          tf.dispose([nextImageTensor, tensor4d, float32Tensor, prediction]);
        } catch (e) {
          console.log('ERROR PREDICTING FROM MODEL');
          console.log(e);
        }
      }
      requestAnimationFrame(loop);
    };
    loop();
  };

  return (
    <View style={styles.container}>
      {tfReady && focus && modelReady && (
        <>
          <TensorCamera
            // Standard Camera props
            style={styles.camera}
            type={Camera.Constants.Type.back}
            // Tensor related props
            cameraTextureHeight={textureDims.height}
            cameraTextureWidth={textureDims.width}
            resizeHeight={200}
            resizeWidth={150}
            resizeDepth={3}
            onReady={handleCameraStream}
            autorender={true}
          />
        </>
      )}
      {!tfReady && <Text>Loading ...</Text>}
    </View>
  );
}
*** UPDATE ***
More information about the model and the prediction output
model.json
{
  "format": "graph-model",
  "generatedBy": "2.4.0",
  "convertedBy": "TensorFlow.js Converter v1.7.0",
  "userDefinedMetadata": {
    "signature": {
      "inputs": {
        "ToFloat:0": {
          "name": "ToFloat:0",
          "dtype": "DT_FLOAT",
          "tensorShape": {
            "dim": [
              { "size": "-1" },
              { "size": "-1" },
              { "size": "-1" },
              { "size": "3" }
            ]
          }
        }
      },
      "outputs": {
        "Postprocessor/convert_scores:0": {
          "name": "Postprocessor/convert_scores:0",
          "dtype": "DT_FLOAT",
          "tensorShape": {
            "dim": [{ "size": "-1" }, { "size": "-1" }, { "size": "11" }]
          }
        },
        "Postprocessor/Decode/transpose_1:0": {
          "name": "Postprocessor/Decode/transpose_1:0",
          "dtype": "DT_FLOAT",
          "tensorShape": { "dim": [{ "size": "-1" }, { "size": "4" }] }
        }
      }
    }
  },
  "modelTopology": {/*...*/},
  "weightsManifest": [/*...*/]
}
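The signature above lists two outputs: a score tensor with 11 values per candidate box and a box tensor with 4 values per candidate box. As a rough sketch of how the two could be combined, the helper below picks the highest-scoring class for each candidate and keeps only those above a threshold. It assumes both outputs have already been turned into plain nested arrays (for example with the tensor's array() method) and that the batch dimension has been removed. The names pickDetections and minScore are hypothetical, and the sample numbers are made up. Whether class index 0 is a background class depends on how the model was exported, so this sketch does not skip it.

```javascript
// Hypothetical helper: pick the best class for every candidate box.
// scores: number[][] with shape [numBoxes, numClasses]
// boxes:  number[][] with shape [numBoxes, 4]
// ASSUMPTION: row i of scores belongs to row i of boxes.
function pickDetections(scores, boxes, minScore = 0.5) {
  const detections = [];
  for (let i = 0; i < scores.length; i++) {
    const row = scores[i];
    // Plain argmax over the class scores of this candidate
    let bestClass = 0;
    for (let c = 1; c < row.length; c++) {
      if (row[c] > row[bestClass]) bestClass = c;
    }
    // Keep the candidate only if its best score passes the threshold
    if (row[bestClass] >= minScore) {
      detections.push({
        classIndex: bestClass,
        score: row[bestClass],
        box: boxes[i],
      });
    }
  }
  return detections;
}

// Made-up example with 3 classes instead of 11, for brevity
const exampleScores = [
  [0.1, 0.7, 0.2],
  [0.05, 0.1, 0.02],
];
const exampleBoxes = [
  [0.1, 0.2, 0.5, 0.6],
  [0.0, 0.0, 1.0, 1.0],
];
const exampleDetections = pickDetections(exampleScores, exampleBoxes, 0.5);
console.log(exampleDetections);
```

Raw detector outputs like these often contain many overlapping candidates for the same object, so a non-max suppression step (TensorFlow.js provides tf.image.nonMaxSuppressionAsync) would probably be needed before drawing anything.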
Prediction output (data tensor)
{"dataId": {"id": 22594}, "dtype": "int32", "id": 17510, "isDisposedInternal": false, "kept": false, "rankType": "3", "shape": [200, 150, 3], "size": 90000, "strides": [450, 3]}
Tensor
[[[0.0002148, 0.0004637, 0.0005074, 0.0002892, 0.0006514, 0.0002825, 0.000659 , 0.0004711, 0.0006962, 0.0002513, 0.0007014],
[0.0002115, 0.000315 , 0.0005155, 0.0002003, 0.0006719, 0.0003006, 0.000607 , 0.00035 , 0.000555 , 0.0003226, 0.0011692],
[0.0002034, 0.0005054, 0.0007887, 0.0003393, 0.0008593, 0.0003684, 0.0009112, 0.0006189, 0.0007553, 0.0006771, 0.0009623],
...,
[0.0031853, 0.0024052, 0.0078735, 0.0032234, 0.0032864, 0.0030518, 0.007637 , 0.0053635, 0.0085449, 0.0039902, 0.0059357],
[0.0031471, 0.0018387, 0.0050392, 0.0019646, 0.0024433, 0.0026016, 0.0039139, 0.0029011, 0.0051994, 0.0027256, 0.0041809],
[0.0032482, 0.0017414, 0.0041161, 0.0016489, 0.0021324, 0.001853 , 0.0030632, 0.0022793, 0.0032864, 0.0045204, 0.007637 ]]]
Solution
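One possible way to show boxes over the live camera feed is to draw ordinary absolutely positioned views on top of the camera component, one per detected box. The sketch below only covers the coordinate conversion. It assumes each box is given as [ymin, xmin, ymax, xmax] with values normalized to the range 0..1, which is a common convention for TensorFlow object detection models but is not confirmed for this model. The name boxToOverlayStyle and the preview size used in the example are hypothetical.

```javascript
// Hypothetical helper: map one normalized box to absolute-position style values.
// ASSUMPTION: box = [ymin, xmin, ymax, xmax], each value in the range 0..1.
function boxToOverlayStyle(box, previewWidth, previewHeight) {
  // Clamp to 0..1 so a box never extends beyond the preview area
  const [ymin, xmin, ymax, xmax] = box.map((v) => Math.min(1, Math.max(0, v)));
  return {
    position: 'absolute',
    left: xmin * previewWidth,
    top: ymin * previewHeight,
    width: (xmax - xmin) * previewWidth,
    height: (ymax - ymin) * previewHeight,
  };
}

// Made-up example: a box covering the right half, middle band of a 400x800 preview
const exampleStyle = boxToOverlayStyle([0.25, 0.5, 0.75, 1.0], 400, 800);
console.log(exampleStyle);
```

In the screen component, the detections for the latest frame could be kept in state, and one bordered View per detection rendered inside the same container as the camera, using the returned object as its style. Since the model sees a resized frame, the preview width and height passed in would be the on-screen size of the camera view rather than the tensor size.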