python - 如何在 Python 中使用映射泛型
问题描述
我想构建一个(抽象)数据集,它是通用的,只提供加载文件的框架。
然后子类用于特定类型(这里打扰示例并且注释属于类型np.ndarray
当我实例化一个类型的对象时,ImageDataset
我得到
File "/home/maximilian/darts/tests/test_dataset.py", line 12, in test_simple_loading
dataset = ImageDataset(dataset_root)
File "/home/maximilian/darts/darts/datasets.py", line 72, in __init__
super(ImageDataset, self).__init__(path)
File "/home/maximilian/darts/darts/datasets.py", line 19, in __init__
self.__load_data(path)
File "/home/maximilian/darts/darts/datasets.py", line 48, in __load_data
self.annotations.update(annotations)
AttributeError: 'ImageDataset' object has no attribute 'annotations'
谁能告诉我我在这里做错了什么?
from collections import defaultdict
from abc import abstractmethod
from itertools import tee
from pathlib import Path
from typing import Iterator, TypeVar, Tuple, Dict, Mapping
import numpy as np
from cv2 import haveImageReader, imread
Key = str
Annotation = TypeVar('Annotation')
Sample = TypeVar('Sample')
AnnotatedSample = Tuple[Sample, Annotation]
class Dataset(Mapping[Key, AnnotatedSample]):
def __init__(self, path: Path):
self.__path = path
self.__load_data(path)
self.annotations: Dict[str, Annotation] = defaultdict(lambda: None)
self.samples: Dict[str, Sample] = defaultdict(lambda: None)
@abstractmethod
def _is_sample_file(self, file : Path) -> bool:
raise NotImplementedError()
@abstractmethod
def _is_annotation_file(self, file : Path) -> bool:
raise NotImplementedError()
@abstractmethod
def _load_annotation(self, file: Path) -> Annotation:
raise NotImplementedError()
@abstractmethod
def _load_sample(self, file: Path) -> Sample:
raise NotImplementedError()
def __load_data(self, path: Path):
files = filter(lambda file: not file.is_dir(), path.glob('*'))
it1, it2 = tee(files)
annotations_files = filter(lambda file: self._is_annotation_file(file), it1)
sample_files = filter(lambda file: self._is_sample_file(file), it2)
annotations = map(lambda file: (file.stem, self._load_annotation(file)), annotations_files)
samples = map(lambda file: (file.stem, self._load_sample(file)), sample_files)
self.annotations.update(annotations)
self.samples.update(samples)
annotation_keys = set(self.annotations)
samples_keys = set(self.samples)
annotations_without_sample = annotation_keys.difference(samples_keys)
if annotations_without_sample:
raise ValueError(
f"For each annotation a sample file must be given. Annotation without sample {annotations_without_sample} ")
def __getitem__(self, k: Key) -> AnnotatedSample:
return self.samples[k], self.annotations
def __len__(self) -> int:
return len(self.samples)
def __iter__(self) -> Iterator[Key]:
return self.samples.keys()
class ImageDataset(Dataset[np.ndarray, np.ndarray]):
ANNOTATION_EXTENSIONS = ['.npy']
def __init__(self, path : Path, transformations = []):
super(ImageDataset, self).__init__(path)
self.__transformations = transformations
def _is_annotation_file(self, file: Path) -> bool:
return haveImageReader(str(file))
def _is_sample_file(self, file: Path) -> bool:
return file.stem in ImageDataset.ANNOTATION_EXTENSIONS
def _load_annotation(self, file: Path) -> Annotation:
return np.load(str(file))
def _load_sample(self, file: Path) -> Sample:
image = imread(str(file))
for transform in self.__transformations:
image = transform(image)
return image
解决方案
推荐阅读
- dart - Flutter json将对象的数组映射到类
- vba - Excel VBA:查找最后一行
- typescript - 如何合并对象的通用列表?
- c# - IPHostEntry - 需要强制它从另一个域返回 FQDN 主机名
- java - 在自定义验证器 Spring Rest 中返回 HTTP 代码
- sql - HAWQ PostgreSQL - 基于前一行的增量行
- python - Python:从字典中提取值并将所有值添加到新列表中
- node.js - Sequelize - 要定义外键,我应该使用引用还是 belongsTo?或两者?
- jquery - 即使我没有初始化或将其设置为值,JavaScript 对象也存在
- javascript - 使用 Office Addin 更改 MS Word 中的主菜单