Object detection pipeline using tf.data

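Problem description

I would appreciate it if someone could point me in the right direction.

I am trying to build a data pipeline to train an object detection model, and I would like to use tf.data.Dataset for it. Whichever way I approach this, I run into problems. The code below is the closest I have come, but it only works with a batch size of 1 (a batch size of 2 or more raises a batching error). As soon as a batch contains more than one image and the images have different numbers of bounding boxes, I get this error: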

InvalidArgumentError: Cannot add tensor to the batch: number of elements does not match. Shapes are: [tensor]: [5,4], [batch]: [7,4]

In the error above, the first image has 7 bounding boxes and the second has 5.

Below is my latest code for building the pipeline.

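import os

import numpy as np
import pandas as pd
import tensorflow as tf

# image_paths, label_source and class2lbl are defined elsewhere (not shown).
def data_generator():
    '''Generator that yields training samples indefinitely.

    Yields: a dictionary containing
        img_path: string with the path to the image
        all_bbox: (n, 4) numpy array of [x_t, y_t, x_b, y_b] boxes
        cls_lbl: (n,) numpy array of class labels
                 ** where n is the number of objects in the image
    '''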
    while True:
        for img_path in image_paths:
            lbl_file = label_source + '/' + os.path.basename(img_path).replace('.png', '.txt')
            # r'\s+' treats any run of whitespace as a single separator
            lbl_df = pd.read_csv(lbl_file, sep=r'\s+', header=None, engine='python')
            all_bbox = []
            cls_lbl = []
            # Load bounding boxes and class labels, skipping ignored classes
            for _, row in lbl_df.iterrows():
                if row[0] in ['Misc', 'DontCare']:
                    continue
                x_t, y_t = int(row[4]), int(row[5])
                x_b, y_b = int(row[6]), int(row[7])
                all_bbox.append([x_t, y_t, x_b, y_b])
                cls_lbl.append(class2lbl[row[0]])
            yield {"img_path": img_path, 
                   "all_bbox": np.array(all_bbox),
                   "cls_lbl": np.array(cls_lbl)}

def image_loader(sample):
    '''Read and decode the image, then return the updated sample dictionary.'''
    raw_img = tf.io.read_file(sample['img_path'])
    img = tf.io.decode_png(raw_img)
    sample["img"] = img
    sample["all_bbox"] = tf.cast(sample['all_bbox'], dtype=tf.float32)
    sample["cls_lbl"] = tf.cast(sample['cls_lbl'], dtype=tf.float32)
    return sample

train_dataset = tf.data.Dataset.from_generator(data_generator, output_types={"img_path":tf.string, 
                                                                             "all_bbox":tf.float32, 
                                                                             "cls_lbl":tf.float32})
train_dataset = train_dataset.map(image_loader)
train_dataset = train_dataset.batch(2)

val = next(iter(train_dataset))

Tags: object-detection, tf.data.dataset

Solution
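
One possible approach, offered as a sketch under assumptions rather than as the accepted answer: Dataset.batch stacks elements into a single dense tensor, so every element must have the same shape, which fails once images carry different numbers of boxes. Dataset.padded_batch instead pads each component up to the largest shape in the batch. The example below uses a hypothetical toy_generator with synthetic arrays in place of data_generator and real image files, declares the rank of each component through output_signature (available in TensorFlow 2.4 and later), and pads boxes and labels with -1, which assumes real class labels are never negative.

```python
import numpy as np
import tensorflow as tf


def toy_generator():
    # Hypothetical stand-in for data_generator: two samples with 7 and 5 boxes.
    for n in (7, 5):
        yield {
            "img": np.zeros((8, 8, 3), dtype=np.uint8),
            "all_bbox": np.ones((n, 4), dtype=np.float32),
            "cls_lbl": np.ones((n,), dtype=np.float32),
        }


# output_signature tells tf.data the rank of each component;
# None marks a dimension that varies from sample to sample.
dataset = tf.data.Dataset.from_generator(
    toy_generator,
    output_signature={
        "img": tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        "all_bbox": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        "cls_lbl": tf.TensorSpec(shape=(None,), dtype=tf.float32),
    },
)

# Pad every component to the largest shape found in each batch.
# Assumption: -1 never occurs as a real label, so it can mark padding.
batched = dataset.padded_batch(
    2,
    padded_shapes={"img": [None, None, 3], "all_bbox": [None, 4], "cls_lbl": [None]},
    padding_values={
        "img": tf.constant(0, dtype=tf.uint8),
        "all_bbox": tf.constant(-1.0, dtype=tf.float32),
        "cls_lbl": tf.constant(-1.0, dtype=tf.float32),
    },
)

batch = next(iter(batched))

# Boolean mask that separates real boxes from padded rows.
valid = batch["cls_lbl"] >= 0
num_valid = tf.reduce_sum(tf.cast(valid, tf.int32), axis=1)
```

Here batch["all_bbox"] has one row per box slot, with the shorter sample filled out by rows of -1, and valid can be used in the loss to ignore those padded rows. With real images of different sizes, the "img" component would also be zero-padded to the largest height and width in the batch, so resizing images to a fixed size in the map step may be preferable.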

