Getting len(dataset) = 0 in an object detection problem

Problem description

I am working on an object detection problem using a fruit dataset: https://yadi.sk/d/UPwQB7OZrB48qQ. I was given this code for my dataset class:

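import glob
import os

import albumentations as A
import cv2
import torch
import xmltodict
from torch.utils.data import Dataset

class2tag = {"apple": 1, "orange": 2, "banana": 3}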


class FruitDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.images = []
        self.annotations = []
        self.transform = transform
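        # glob returns an empty list (no error) if data_dir is wrong or has no XML files
        for annotation in sorted(glob.glob(os.path.join(data_dir, "*.xml"))):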
            image_fname = os.path.splitext(annotation)[0] + ".jpg"
            self.images.append(cv2.cvtColor(cv2.imread(image_fname), cv2.COLOR_BGR2RGB))
            with open(annotation) as f:
                annotation_dict = xmltodict.parse(f.read())
            bboxes = []
            labels = []
            objects = annotation_dict["annotation"]["object"]
            if not isinstance(objects, list):
                objects = [objects]
            for obj in objects:
                bndbox = obj["bndbox"]
                bbox = [bndbox["xmin"], bndbox["ymin"], bndbox["xmax"], bndbox["ymax"]]
                bbox = list(map(int, bbox))
                bboxes.append(torch.tensor(bbox))
                labels.append(class2tag[obj["name"]])
            self.annotations.append(
                {"boxes": torch.stack(bboxes).float(), "labels": torch.tensor(labels)}
            )

    def __getitem__(self, i):
        if self.transform:
            # the following code is correct if you use albumentations
            # if you use torchvision transforms you have to modify it =)
            res = self.transform(
                image=self.images[i],
                bboxes=self.annotations[i]["boxes"],
                labels=self.annotations[i]["labels"],
            )
            return res["image"], {
                "boxes": torch.tensor(res["bboxes"]),
                "labels": torch.tensor(res["labels"]),
            }
        else:
            return self.images[i], self.annotations[i]

    def __len__(self):
        return len(self.images)

I am doing my project in Google Colab, so I mounted Google Drive and unzipped the archive.

from google.colab import drive
drive.mount('/content/drive')
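
The unzip step itself is not shown in the post. A minimal sketch of what it might look like, assuming a hypothetical archive location on the mounted Drive: the helper name `extract_archive` and both paths are illustrative and not taken from the question. Listing the top-level entries after extraction shows the folder layout, which determines the path that has to be passed to FruitDataset.

```python
import os
import zipfile


def extract_archive(archive_path, dest_dir):
    """Extract a zip archive into dest_dir and return the sorted names
    of the entries now present at the top level of dest_dir."""
    os.makedirs(dest_dir, exist_ok=True)
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(dest_dir)
    return sorted(os.listdir(dest_dir))


if __name__ == "__main__":
    # Hypothetical paths: adjust to wherever the archive actually lives.
    archive = "/content/drive/MyDrive/train_zip.zip"
    if os.path.exists(archive):
        print(extract_archive(archive, "./train_zip"))
```

If the archive contains a nested folder, the XML files may sit one level deeper than expected, so it is worth checking the printed names before building the dataset path.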


Then I set up some augmentations using albumentations:

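train_transform = A.Compose(
    [A.Flip(p=0.25), A.RGBShift(p=0.2)],
    # Boxes are [xmin, ymin, xmax, ymax], i.e. "pascal_voc" rather than "coco";
    # label_fields names the "labels" argument passed in __getitem__.
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)
val_transform = A.Compose(
    [], bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"])
)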

train_dataset = FruitDataset("./train_zip/train", transform=train_transform)
val_dataset = FruitDataset("./test_zip/test", transform=val_transform)

However, when I run len(train_dataset), I get 0. I cannot work out why my dataset is empty or where the problem lies. Any suggestions would be much appreciated.

Tags: python, neural-network, albumentations

Solution
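
One thing that is certain from the code: glob.glob returns an empty list, without raising, when the directory does not exist or contains no matching files. In that case the loop in __init__ never runs, self.images stays empty, and len(train_dataset) is 0. A plausible cause, though the post does not confirm it, is that the relative path "./train_zip/train" does not point to where the archive was actually extracted. The sketch below makes that failure loud instead of silent; `find_annotations` is a hypothetical helper, not part of the original code.

```python
import glob
import os


def find_annotations(data_dir):
    """Return the sorted .xml paths directly inside data_dir.

    Raises FileNotFoundError if the directory is missing or holds no
    .xml files, instead of silently returning an empty list.
    """
    data_dir = os.path.abspath(data_dir)
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(
            f"{data_dir} is not a directory (working directory: {os.getcwd()})"
        )
    paths = sorted(glob.glob(os.path.join(data_dir, "*.xml")))
    if not paths:
        raise FileNotFoundError(f"no .xml files directly inside {data_dir}")
    return paths
```

Calling find_annotations("./train_zip/train") in the notebook reports the absolute path that is really being searched, which can then be compared with the folder the archive was unzipped into.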

