google-cloud-platform - Google Cloud Vision API 对象检测模型在 Raspberry Pi 上出现总线错误
问题描述
我使用 Vision API 在 Google Cloud 上训练了一个简单的对象检测模型。将其导出为 TFLite 模型后,我尝试在 Raspberry Pi 3B+ 上用下面这段简单的入门代码和 TensorFlow Lite 2.6.0-rc2 运行它。这段代码可以正常运行标准 MobileNet 模型,但换成我的自定义模型后,在分配张量时会出现总线错误(SIGBUS)。随后我在 WSL(Debian)上用同一个模型运行相同的代码,一切正常。Vision API 声称支持 ARM 边缘设备,所以我不明白为什么它在树莓派上不起作用。是树莓派内存不够吗?如果是这样,为什么它却能运行更复杂的 MobileNet 模型?
测试代码修改自 https://github.com/Qengineering/TensorFlow_Lite_SSD_RPi_32-bits :
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/core/ocl.hpp>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/model.h"
using namespace cv;
using namespace std;
// Expected model input resolution in pixels (width x height).
constexpr size_t width  = 192;
constexpr size_t height = 192;

// Class names read from labels.txt by getFileContent(), one entry per non-empty line.
std::vector<std::string> Labels{};

// The single TFLite interpreter built in main() and used for every frame.
std::unique_ptr<tflite::Interpreter> interpreter{};
/**
 * Appends every non-empty line of a text file to the global `Labels` vector.
 *
 * A trailing carriage return is stripped from each line, so a label file
 * saved with Windows (CRLF) line endings does not leave '\r' at the end of
 * every label. Empty lines are skipped; a label's index in `Labels` is
 * therefore its position among the non-empty lines only.
 *
 * @param fileName path of the label file, one label per line.
 * @return false if the file could not be opened, true otherwise.
 */
static bool getFileContent(const std::string& fileName)
{
    std::ifstream in(fileName);
    if (!in.is_open()) return false;

    std::string line;
    while (std::getline(in, line)) {
        // CRLF files: getline only strips '\n', so drop a leftover '\r'.
        if (!line.empty() && line.back() == '\r') line.pop_back();
        if (!line.empty()) Labels.push_back(line);
    }
    return true; // the ifstream closes the file when it goes out of scope
}
void detect_from_video(Mat &src)
{
Mat image;
int cam_width =src.cols;
int cam_height=src.rows;
// copy image to input as input tensor
cv::resize(src, image, Size(width,height));
memcpy(interpreter->typed_input_tensor<uchar>(0), image.data, image.total() * image.elemSize());
interpreter->SetAllowFp16PrecisionForFp32(true);
interpreter->SetNumThreads(4); //quad core
// cout << "tensors size: " << interpreter->tensors_size() << "\n";
// cout << "nodes size: " << interpreter->nodes_size() << "\n";
// cout << "inputs: " << interpreter->inputs().size() << "\n";
// cout << "input(0) name: " << interpreter->GetInputName(0) << "\n";
// cout << "outputs: " << interpreter->outputs().size() << "\n";
interpreter->Invoke(); // run your model
const float* detection_locations = interpreter->tensor(interpreter->outputs()[0])->data.f;
const float* detection_classes=interpreter->tensor(interpreter->outputs()[1])->data.f;
const float* detection_scores = interpreter->tensor(interpreter->outputs()[2])->data.f;
const int num_detections = *interpreter->tensor(interpreter->outputs()[3])->data.f;
//there are ALWAYS 10 detections no matter how many objects are detectable
//cout << "number of detections: " << num_detections << "\n";
const float confidence_threshold = 0.5;
for(int i = 0; i < num_detections; i++){
if(detection_scores[i] > confidence_threshold){
int det_index = (int)detection_classes[i]+1;
float y1=detection_locations[4*i ]*cam_height;
float x1=detection_locations[4*i+1]*cam_width;
float y2=detection_locations[4*i+2]*cam_height;
float x2=detection_locations[4*i+3]*cam_width;
Rect rec((int)x1, (int)y1, (int)(x2 - x1), (int)(y2 - y1));
rectangle(src,rec, Scalar(0, 0, 255), 1, 8, 0);
putText(src, format("%s", Labels[det_index].c_str()), Point(x1, y1-5) ,FONT_HERSHEY_SIMPLEX,0.5, Scalar(0, 0, 255), 1, 8, 0);
}
}
}
int main(int argc,char ** argv)
{
float f;
float FPS[16];
int i;
int Fcnt=0;
Mat frame;
chrono::steady_clock::time_point Tbegin, Tend;
for(i=0;i<16;i++) FPS[i]=0.0;
// Load model
std::unique_ptr<tflite::FlatBufferModel> model = tflite::FlatBufferModel::BuildFromFile("detect.tflite");
// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model.get(), resolver)(&interpreter);
interpreter->AllocateTensors();
// Get the names
bool result = getFileContent("labels.txt");
if(!result)
{
cout << "loading labels failed";
exit(-1);
}
// VideoCapture cap("James.mp4");
VideoCapture cap(0);
if (!cap.isOpened()) {
cerr << "ERROR: Unable to open the camera" << endl;
return 0;
}
cout << "Start grabbing, press ESC on Live window to terminate" << endl;
while(1){
// frame=imread("Traffic.jpg"); //need to refresh frame before dnn class detection
cap >> frame;
if (frame.empty()) {
cerr << "End of movie" << endl;
break;
}
detect_from_video(frame);
Tend = chrono::steady_clock::now();
//calculate frame rate
f = chrono::duration_cast <chrono::milliseconds> (Tend - Tbegin).count();
Tbegin = chrono::steady_clock::now();
FPS[((Fcnt++)&0x0F)]=1000.0/f;
for(f=0.0, i=0;i<16;i++){ f+=FPS[i]; }
putText(frame, format("FPS %0.2f",f/16),Point(10,20),FONT_HERSHEY_SIMPLEX,0.6, Scalar(0, 0, 255));
//show output
imshow("RPi 4 - 2.0 GHz - 2 Mb RAM", frame);
char esc = waitKey(5);
if(esc == 27) break;
}
cout << "Closing the camera" << endl;
// When everything done, release the video capture and write object
cap.release();
destroyAllWindows();
cout << "Bye!" << endl;
return 0;
}
总线错误的堆栈跟踪(发生在张量分配期间):
Program terminated with signal SIGBUS, Bus error.
#0 0x00134dd0 in tflite::ops::builtin::broadcastto::ResizeOutputTensor(TfLiteContext*, tflite::ops::builtin::broadcastto::BroadcastToContext*) ()
(gdb) bt
#0 0x00134dd0 in tflite::ops::builtin::broadcastto::ResizeOutputTensor(TfLiteContext*, tflite::ops::builtin::broadcastto::BroadcastToContext*) ()
#1 0x00135194 in tflite::ops::builtin::broadcastto::Prepare(TfLiteContext*, TfLiteNode*) ()
#2 0x000d36c4 in tflite::Subgraph::PrepareOpsStartingAt(int, std::vector<int, std::allocator<int> > const&, int*) ()
#3 0x000d386c in tflite::Subgraph::PrepareOpsAndTensors() ()
#4 0x000d5c64 in tflite::Subgraph::AllocateTensors() ()
#5 0x0001b2cc in tflite::Interpreter::AllocateTensors() ()
#6 0x000161d8 in main(int, char**) (argc=1, argv=0x7ebd0644)
at MobileNetV1.cpp:106
该 TFLite 对象检测模型只针对一种标签进行了训练,共使用 50 张图像(我希望先让模型跑通,再添加更多图像):https://storage.googleapis.com/ml-data-storage-bucket/models/model-export/iod/tflite-ping_pong_ball_1-2021-08-02T19%3A46%3A09.324008Z/model.tflite
解决方案
推荐阅读
- php - 带有闭包 CURL 标头的 PHP 类方法
- npm - 发布 npm 包的部署工作流程
- java - 如何检查文件是否存在于 java 中的 Azure Blob 容器中
- c# - 在 Linux 中找不到 PL2303 设备的序列号属性
- excel - 在 Countifs 中使用 ABS 函数
- apache-kafka - Kafka中'metric.reporters'和'kafka.metrics.reporters'属性之间的区别
- sql - R在读取数据帧时复制大整数
- angular - setValue 数据未显示在表单字段中,当我编辑列表项时
- react-native - React Native 和 Flutter 有什么区别?
- android - 当WIFI被禁用时,用户应该重新打开wifi。但我想检测整个时间是否打开或关闭。我怎样才能做到这一点?