首页 > 解决方案 > 如何在 Python 中使用映射泛型

问题描述

我想构建一个(抽象)数据集,它是通用的,只提供加载文件的框架。

然后子类用于特定类型(这里打扰示例并且注释属于类型np.ndarray

当我实例化一个类型的对象时,ImageDataset我得到

 File "/home/maximilian/darts/tests/test_dataset.py", line 12, in test_simple_loading
    dataset = ImageDataset(dataset_root)
  File "/home/maximilian/darts/darts/datasets.py", line 72, in __init__
    super(ImageDataset, self).__init__(path)
  File "/home/maximilian/darts/darts/datasets.py", line 19, in __init__
    self.__load_data(path)
  File "/home/maximilian/darts/darts/datasets.py", line 48, in __load_data
    self.annotations.update(annotations)
AttributeError: 'ImageDataset' object has no attribute 'annotations'

谁能告诉我我在这里做错了什么?

from collections import defaultdict

from abc import abstractmethod
from itertools import tee
from pathlib import Path
from typing import Iterator, TypeVar, Tuple, Dict, Mapping
import numpy as np
from cv2 import haveImageReader, imread

Key = str
Annotation = TypeVar('Annotation')
Sample = TypeVar('Sample')
AnnotatedSample = Tuple[Sample, Annotation]


class Dataset(Mapping[Key, AnnotatedSample]):
    def __init__(self, path: Path):
        self.__path = path
        self.__load_data(path)
        self.annotations: Dict[str, Annotation] = defaultdict(lambda: None)
        self.samples: Dict[str, Sample] = defaultdict(lambda: None)

    @abstractmethod
    def _is_sample_file(self, file : Path) -> bool:
        raise NotImplementedError()

    @abstractmethod
    def _is_annotation_file(self, file : Path) -> bool:
        raise NotImplementedError()

    @abstractmethod
    def _load_annotation(self, file: Path) -> Annotation:
        raise NotImplementedError()

    @abstractmethod
    def _load_sample(self, file: Path) -> Sample:
        raise NotImplementedError()

    def __load_data(self, path: Path):
        files = filter(lambda file: not file.is_dir(), path.glob('*'))
        it1, it2 = tee(files)
        annotations_files = filter(lambda file: self._is_annotation_file(file), it1)
        sample_files = filter(lambda file: self._is_sample_file(file), it2)

        annotations = map(lambda file: (file.stem, self._load_annotation(file)), annotations_files)
        samples = map(lambda file: (file.stem, self._load_sample(file)), sample_files)

        self.annotations.update(annotations)
        self.samples.update(samples)

        annotation_keys = set(self.annotations)
        samples_keys = set(self.samples)
        annotations_without_sample = annotation_keys.difference(samples_keys)
        if annotations_without_sample:
            raise ValueError(
                f"For each annotation a sample file must be given. Annotation without sample {annotations_without_sample} ")

    def __getitem__(self, k: Key) -> AnnotatedSample:
        return self.samples[k], self.annotations

    def __len__(self) -> int:
        return len(self.samples)

    def __iter__(self) -> Iterator[Key]:
        return self.samples.keys()


class ImageDataset(Dataset[np.ndarray, np.ndarray]):
    ANNOTATION_EXTENSIONS = ['.npy']

    def __init__(self, path : Path, transformations = []):
        super(ImageDataset, self).__init__(path)
        self.__transformations = transformations

    def _is_annotation_file(self, file: Path) -> bool:
        return haveImageReader(str(file))

    def _is_sample_file(self, file: Path) -> bool:
        return file.stem in ImageDataset.ANNOTATION_EXTENSIONS

    def _load_annotation(self, file: Path) -> Annotation:
        return np.load(str(file))

    def _load_sample(self, file: Path) -> Sample:
        image = imread(str(file))
        for transform in self.__transformations:
            image = transform(image)
        return image

标签: pythongenericspython-typing

解决方案


推荐阅读