tensorflow - 带有 tflite 的 SSD mobilenet v1 输出不良
问题描述
背景
我正在使用来自Tensorflow 的 object detection以及Firebase 的 MLInterpreter的源代码。我试图严格遵守文档中规定的步骤。在训练期间,我可以在 TensorBoard 上看到模型正在正确训练,但不知何故,我没有正确导出和连接事物以进行推理。以下是详细信息:
我使用的命令,从训练到 .tflite 文件
首先,我使用 ssd_mobilenet_v1 配置文件提交训练作业。配置文件或多或少与 Tensorflow 默认提供的相同——我只修改了类数和存储桶名称。
gcloud ml-engine jobs submit training `whoami`_<JOB_NAME>_`date +%m_%d_%Y_%H_%M_%S` \
--runtime-version 1.12 \
--job-dir=gs://<BUCKET_NAME>/model_dir \
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz,/tmp/pycocotools/pycocotools-2.0.tar.gz \
--module-name object_detection.model_main \
--region us-central1 \
--config object_detection/samples/cloud/cloud.yml \
-- \
--model_dir=gs://<BUCKET_NAME>/model_dir \
--pipeline_config_path=gs://<BUCKET_NAME>/data/ssd_mobilenet_v1.config
然后我导出tflite_graph.pb
文件:
python models/research/object_detection/export_tflite_ssd_graph.py \
--input_type image_tensor \
--pipeline_config_path ssd_mobilenet_v1.config \
--trained_checkpoint_prefix model.ckpt-264012 \
--output_directory exported_tflite
太好了,此时我有tflite_graph.pb
,并且需要从那里获取实际.tflite
文件:
tflite_convert \
--output_file=model.tflite \
--graph_def_file=exported_tflite/tflite_graph.pb \
--input_arrays=normalized_input_image_tensor \
--output_arrays=TFLite_Detection_PostProcess \
--input_shapes=1,300,300,3 \
--allow_custom_ops
使用 Swift 和 Firebase 执行推理
我想最终使用 AVFoundation 从相机捕获图像,但为了使其更具可读性,我将仅发布代码的相关部分:
这里是初始化模型和设置 ioOptions 的地方。我在export_tflite_ssd_graph(上面使用过)的顶部发现了一条用于确定 ioOptions 的注释,但我仍然不相信我正确配置了这些:
guard let modelPath = Bundle.main.path(forResource: "model", ofType: "tflite") else {
self.interpreter = nil;
super.init()
return;
}
let localModel = CustomLocalModel(modelPath: modelPath)
self.interpreter = ModelInterpreter.modelInterpreter(localModel: localModel)
do {
try self.ioOptions.setInputFormat(index: 0, type: .float32, dimensions: [1, 300, 300, 3])
try self.ioOptions.setOutputFormat(index: 0, type: .float32, dimensions: [1, 10, 4])
try self.ioOptions.setOutputFormat(index: 1, type: .float32, dimensions: [1, 10])
try self.ioOptions.setOutputFormat(index: 2, type: .float32, dimensions: [1, 10])
try self.ioOptions.setOutputFormat(index: 3, type: .float32, dimensions: [1])
} catch let error as NSError {
print("Failed to set input or output format with error: \(error.localizedDescription)")
}
设置好之后,我稍后会使用以下几行来执行推理。基本上,我将数据缓冲区转换为 CGImage,进行一些调整大小,然后将 RGB 值重新打包到可以传递给模型进行推理的缓冲区中:
# Draw the image in a context
guard let context = CGContext(
data: nil,
width: image.width, height: image.height,
bitsPerComponent: 8, bytesPerRow: image.width * 4,
space: CGColorSpaceCreateDeviceRGB(),
bitmapInfo: CGImageAlphaInfo.noneSkipFirst.rawValue
) else {
return;
}
context.draw(image, in: CGRect(x: 0, y: 0, width: image.width, height: image.height))
guard let imageData = context.data else { return; }
# "image" is now a CGImage
let inputs = ModelInputs()
var inputData = Data()
do {
for row in 0 ..< 300 {
for col in 0 ..< 300 {
let offset = 4 * (col * context.width + row)
// (Ignore offset 0, the unused alpha channel)
let red = imageData.load(fromByteOffset: offset+1, as: UInt8.self)
let green = imageData.load(fromByteOffset: offset+2, as: UInt8.self)
let blue = imageData.load(fromByteOffset: offset+3, as: UInt8.self)
var normalizedRed = Float32(red) / 255.0
var normalizedGreen = Float(green) / 255.0
var normalizedBlue = Float(blue) / 255.0
// Append normalized values to Data object in RGB order.
let elementSize = MemoryLayout.size(ofValue: normalizedRed)
var bytes = [UInt8](repeating: 0, count: elementSize)
memcpy(&bytes, &normalizedRed, elementSize)
inputData.append(&bytes, count: elementSize)
memcpy(&bytes, &normalizedGreen, elementSize)
inputData.append(&bytes, count: elementSize)
memcpy(&bytes, &normalizedBlue, elementSize)
inputData.append(&bytes, count: elementSize)
}
}
try inputs.addInput(inputData)
} catch let error {
print("Failed to add input: \(error)")
}
guard let interpret = self.interpreter else { return; }
print("Running interpreter")
interpret.run(inputs: inputs, options: self.ioOptions) { outputs, error in
guard error == nil, let outputs = outputs else { return; }
do {
try print(outputs.output(index: 1))
try print(outputs.output(index: 2))
...
} catch let error {
print(error)
}
}
问题/问题
经过几个小时尝试将数据转换为不会引发错误的格式后,我实际上终于得到了输出。
问题是,输出概率非常低,并且类几乎从不正确。我知道我的模型比这具有更好的准确性,并且感觉在获取检查点文件和实际对 .tflite 文件运行推理之间做错了什么。
任何从事对象检测工作的人都可以看到我可能偏离了路线吗?
解决方案
推荐阅读
- reactjs - React + TS:如何获得所有孩子的身高?
- delphi - 如何在我的 Delphi 应用程序中集成 VLC 媒体播放器
- sharepoint - 样式库存在权限问题 SharePoint Online
- python-iris - 如何在 Iris 中提取一些 NEMO 海洋模型输出的区域?
- oop - OOP UML 类图,一个人有多个角色。它的结构应该是怎样的?
- c# - 在 mac 上添加未在 vs 2019 中显示的参考
- android - 在Android手机上,应用程序可以选择使用的APN吗?
- datepicker - Form中的SwiftUI DatePicker随机显示不同的文本
- python - 替换 Pandas 单元格中的多个值
- python - 实时对象跟踪 - 如何让视频在开始播放,让用户暂停,绘制边界框,然后开始跟踪?