What do I need to write on the labels line in torchvision for 21 classes?

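Problem description

In this code, I found the line labels = torch.ones((records.shape[0],), dtype=torch.int64), which assigns every box to a single class (label 1). With Faster R-CNN, 0 is reserved for the background. What should this line be for 21 classes?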

    def __getitem__(self, index: int):

        file_name = self.file_names[index]
        records = self.data[self.data['file_name'] == file_name]

        image = np.array(Image.open(file_name), dtype=np.float32)
        image /= 255.0

        if self.transform:
            image = self.transform(image)

        if self.mode != "test":
            boxes = records[['xmin', 'ymin', 'xmax', 'ymax']].values

            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            area = torch.as_tensor(area, dtype=torch.float32)

            labels = torch.ones((records.shape[0],), dtype=torch.int64)

            iscrowd = torch.zeros((records.shape[0],), dtype=torch.int64)

            target = {}

            target['boxes'] = boxes
            target['labels'] = labels
            target['image_id'] = torch.tensor([index])
            target['area'] = area
            target['iscrowd'] = iscrowd
            target['boxes'] = torch.stack(list((map(torch.tensor, target['boxes'])))).type(torch.float32)

            return image, target, file_name
        else:
            return image, file_name

Tags: computer-vision

Solution
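
One possible approach, sketched under assumptions the question does not confirm: instead of filling labels with ones, each box gets the integer id of its own class, with ids starting at 1 because 0 is reserved for the background. The sketch assumes that "21 classes" means 20 object categories plus background (as in Pascal VOC), that the category names below match the annotations, and that records has a per-box class column, here given the hypothetical name 'class_name'. None of these names come from the original code.

```python
# Hypothetical category list (the 20 Pascal VOC object classes).
# Replace it with the categories that actually occur in your annotations.
CLASS_NAMES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

# Index 0 is reserved for the background, so object classes start at 1.
CLASS_TO_IDX = {name: idx for idx, name in enumerate(CLASS_NAMES, start=1)}

# 20 object classes + 1 background class = 21.
NUM_CLASSES = len(CLASS_NAMES) + 1


def encode_labels(class_names):
    """Map an iterable of per-box class names to integer ids in 1..20."""
    return [CLASS_TO_IDX[name] for name in class_names]


# Inside __getitem__, the labels line could then become
# (assuming a hypothetical 'class_name' column in `records`):
#     labels = torch.as_tensor(encode_labels(records['class_name']),
#                              dtype=torch.int64)
```

The model head would need the same count, for example by building the box predictor with num_classes=NUM_CLASSES. If the dataset really has 21 object categories rather than 20, the ids would run from 1 to 21 and the model would need 22 outputs.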

