首页 > 解决方案 > Tensorflow 对象检测 api 示例不适用于 Tensorflow 1x

问题描述

这是我的第一个 ML 项目,所以我可能在理解上犯了错误。我已经使用教程Tensorflow 对象检测 API Github训练了一个自定义模型。我很困惑,因为这个页面说当前对象检测 api 不支持 Tensorflow 2.x。但是,当我阅读演示时 ,它要求安装 Tensorflow 2.x。

我不明白什么时候物体检测api不支持tensorflow v2.x,为什么demo要求安装tensorflow v2.x?你们能帮我理解这个吗?我肯定错过了什么。

编辑 1:当我尝试使用 Tensorflow 1.15 运行演示脚本时,出现以下错误

File "object_detection_custom.py", line 71, in run_inference_for_single_image
    num_detections = int(output_dict.pop('num_detections'))
TypeError: int() argument must be a string or a number, not 'Tensor'

编辑2:下面是模型调用的输出

{
    u 'detection_boxes': < tf.Tensor 'StatefulPartitionedCall:0' shape = ( ? , 100, 4) dtype = float32 > , 
    u 'detection_classes': < tf.Tensor 'StatefulPartitionedCall:1' shape = ( ? , 100) dtype = float32 > , 
    u 'raw_detection_scores': < tf.Tensor 'StatefulPartitionedCall:6' shape = ( ? , ? , 2) dtype = float32 > , 
    u 'detection_scores': < tf.Tensor 'StatefulPartitionedCall:3' shape = ( ? , 100) dtype = float32 > , 
    u 'detection_multiclass_scores': < tf.Tensor 'StatefulPartitionedCall:2' shape = ( ? , 100, 2) dtype = float32 > , 
    u 'num_detections': < tf.Tensor 'StatefulPartitionedCall:4' shape = ( ? , ) dtype = float32 > , 
    u 'raw_detection_boxes': < tf.Tensor 'StatefulPartitionedCall:5' shape = ( ? , ? , 4) dtype = float32 >
}

标签: tensorflowobject-detectionobject-detection-api

解决方案


您上面引用的演示链接已损坏,但您引用的笔记本已移动。本教程确实使用了 TensorFlow 2,但仅用于推理。训练需要 TensorFlow 1.15。


推荐阅读