首页 > 解决方案 > Tensorflow对象检测,在教程中重新排列代码时列表索引超出范围错误

问题描述

我正在尝试 Tensorflow 2 对象检测 API。

我在这个链接上运行了教程中的代码,一切运行都没有问题。

但是,我尝试重新组织该代码,现在我有类似的东西:

import os
import cv2
import numpy as np
import tarfile
import urllib.request
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'    # Suppress TensorFlow logging
tf.get_logger().setLevel('ERROR')           # Suppress TensorFlow logging (2)
# Enable GPU dynamic memory allocation
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

def main():

    ## create folders
    data_dir, models_dir = create_data_directories()

    ## --------------------------
    ## download and extract model
    ## --------------------------
    print('Download and extract model')
    model_date = '20200711'
    model_name = 'ssd_resnet50_v1_fpn_640x640_coco17_tpu-8'
    label_filename = 'mscoco_label_map.pbtxt'
    PATH_TO_CKPT, PATH_TO_CFG, PATH_TO_LABELS = download_models_labels(data_dir, models_dir, model_date, model_name, label_filename)

    ## --------------------------
    # Load pipeline config and build a detection model
    ## --------------------------
    configs = config_util.get_configs_from_pipeline_file(PATH_TO_CFG)
    model_config = configs['model']
    detection_model = model_builder.build(model_config=model_config, is_training=False)
    # Restore checkpoint
    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
    ckpt.restore(os.path.join(PATH_TO_CKPT, 'ckpt-0')).expect_partial()


    ## --------------------------
    # Load label map data (for plotting)
    ## --------------------------
    category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS,
                                                                        use_display_name=True)
    ## --------------------------
    # Define the video stream
    ## --------------------------
    cap = cv2.VideoCapture(2)

    while True:
        # Read frame from camera
        ret, image_np = cap.read()

        # Expand dimensions since the model expects images to have batch -> shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)


        detect_fn = get_model_detection_function(detection_model)
        input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
        detections, predictions_dict, shapes = detect_fn(input_tensor)

        label_id_offset = 1
        image_np_with_detections = image_np.copy()

        viz_utils.visualize_boxes_and_labels_on_image_array(
            image_np_with_detections,
            detections['detection_boxes'][0].numpy(),
            (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
            detections['detection_scores'][0].numpy(),
            category_index,
            use_normalized_coordinates=True,
            max_boxes_to_draw=200,
            min_score_thresh=.60,
            agnostic_mode=False)

        # Display output
        cv2.imshow('object detection', cv2.resize(image_np_with_detections, (800, 600)))

        if cv2.waitKey(25) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

def create_data_directories():
    print('Create the data directory')
    DATA_DIR = os.path.join(os.getcwd(), 'data')
    MODELS_DIR = os.path.join(DATA_DIR, 'models')
    for dir in [DATA_DIR, MODELS_DIR]:
        if not os.path.exists(dir):
            os.mkdir(dir)
    return DATA_DIR, MODELS_DIR

def download_models_labels(DATA_DIR, MODELS_DIR, MODEL_DATE, MODEL_NAME, label_filename):
    # Download the model
    # ~~~~~~~~~~~~~~~~~~
    MODEL_TAR_FILENAME = MODEL_NAME + '.tar.gz'
    MODELS_DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/tf2/'
    MODEL_DOWNLOAD_LINK = MODELS_DOWNLOAD_BASE + MODEL_DATE + '/' + MODEL_TAR_FILENAME
    PATH_TO_MODEL_TAR = os.path.join(MODELS_DIR, MODEL_TAR_FILENAME)
    PATH_TO_CKPT = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, 'checkpoint/'))
    PATH_TO_CFG = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, 'pipeline.config'))
    if not os.path.exists(PATH_TO_CKPT):
        print('Downloading model. This may take a while... ', end='')
        urllib.request.urlretrieve(MODEL_DOWNLOAD_LINK, PATH_TO_MODEL_TAR)
        tar_file = tarfile.open(PATH_TO_MODEL_TAR)
        tar_file.extractall(MODELS_DIR)
        tar_file.close()
        os.remove(PATH_TO_MODEL_TAR)
        print('Done')

    # Download labels file
    LABELS_DOWNLOAD_BASE = \
        'https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/'
    PATH_TO_LABELS = os.path.join(MODELS_DIR, os.path.join(MODEL_NAME, label_filename))
    if not os.path.exists(PATH_TO_LABELS):
        print('Downloading label file... ', end='')
        urllib.request.urlretrieve(LABELS_DOWNLOAD_BASE + label_filename, PATH_TO_LABELS)
        print('Done')

    return PATH_TO_CKPT, PATH_TO_CFG, PATH_TO_LABELS



def get_model_detection_function(model):
##Get a tf.function for detection

    @tf.function
    def detect_fn(image):
        """Detect objects in image."""
        image, shapes = model.preprocess(image)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)
        return detections, prediction_dict, tf.reshape(shapes, [-1])
    return detect_fn

if __name__ == "__main__":
    main()

因此,我只是重新排列了所有内容以使其(我认为!)更具可读性。但是,在我修改之后,我得到了错误:

 /home/lews/anaconda3/envs/tf/lib/python3.8/site-packages/object_detection/models/ssd_resnet_v1_fpn_keras_feature_extractor.py:204 preprocess  *
        if resized_inputs.shape.as_list()[3] == 3:
    IndexError: list index out of range

我在这里找到了相同问题的答案,并遵循了创建返回函数的建议detect_fn,但仍然出现错误。

显然,我可以坚持使用教程中的原始代码,但我有兴趣了解我的修改发生了什么。

标签: pythontensorflowobject-detection-api

解决方案


推荐阅读