Deep copy nested iterable (or improved itertools.tee for iterable of iterables)

Question

Preface

I have a test where I'm working with nested iterables (by nested iterable I mean iterable with only iterables as elements).

As a test case, consider

from itertools import tee
from typing import (Any,
                    Iterable)


def foo(nested_iterable: Iterable[Iterable[Any]]) -> Any:
    ...


def test_foo(nested_iterable: Iterable[Iterable[Any]]) -> None:
    original, target = tee(nested_iterable)  # shallow: the elements themselves are not copied

    result = foo(target)

    assert is_contract_satisfied(result, original)


def is_contract_satisfied(result: Any,
                          original: Iterable[Iterable[Any]]) -> bool:
    ...

E.g. foo may be a simple identity function

def foo(nested_iterable: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
    return nested_iterable

and the contract simply checks that the flattened iterables have the same elements

from itertools import (chain,
                       starmap,
                       zip_longest)
from operator import eq
...
flatten = chain.from_iterable


def is_contract_satisfied(result: Iterable[Iterable[Any]],
                          original: Iterable[Iterable[Any]]) -> bool:
    return all(starmap(eq,
                       zip_longest(flatten(result), flatten(original),
                                   # we're assuming that ``object()``
                                   # will create some unique object
                                   # not presented in any of arguments
                                   fillvalue=object())))

But if any element of nested_iterable is an iterator, it may be exhausted, since tee makes shallow copies, not deep ones: both copies share the same element objects. For the given foo and is_contract_satisfied, the following statement

>>> test_foo([iter(range(10))])

leads to predictable

Traceback (most recent call last):
  ...
    test_foo([iter(range(10))])
  File "...", line 19, in test_foo
    assert is_contract_satisfied(result, original)
AssertionError
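
The failure can be reproduced in isolation. The sketch below (variable names are illustrative and not part of the code above) suggests that both tee branches hand out the very same inner iterator object, so whichever branch is consumed first drains it for the other.

```python
from itertools import tee

inner = iter(range(3))
original, target = tee([inner])

# Both branches yield the same inner iterator object, not a copy of it,
# so the branch consumed first leaves nothing for the other one.
consumed_by_target = [list(sub_iterable) for sub_iterable in target]
left_for_original = [list(sub_iterable) for sub_iterable in original]

print(consumed_by_target)  # [[0, 1, 2]]
print(left_for_original)   # [[]]
```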

Problem

How to deep copy an arbitrary nested iterable?

Note

I'm aware of copy.deepcopy function, but it won't work for file objects.
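
A minimal sketch of that limitation: copy.deepcopy raises TypeError for an open file object. Generators fail the same way, which is an addition not mentioned above. The helper try_deepcopy is hypothetical, and os.devnull is used only because it is always available to open.

```python
import copy
import os


def try_deepcopy(object_):
    """Return the deep copy, or the TypeError raised while copying."""
    try:
        return copy.deepcopy(object_)
    except TypeError as error:
        return error


with open(os.devnull) as file:
    file_result = try_deepcopy(file)

generator_result = try_deepcopy(element for element in range(3))

print(type(file_result).__name__)       # TypeError
print(type(generator_result).__name__)  # TypeError
```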

Tags: python, itertools, iterable

Answer


Naive solution

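A straightforward algorithm would be

  1. Copy the original nested iterable elementwise, i.e. tee each of its elements.
  2. Make n copies of that elementwise copy.
  3. From each independent copy, pick the coordinate (index) that belongs to it.

which may be implemented like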

from itertools import tee
from operator import itemgetter
from typing import (Any,
                    Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


def copy_nested_iterable(nested_iterable: Iterable[Iterable[Domain]],
                         *,
                         count: int = 2
                         ) -> Tuple[Iterable[Iterable[Domain]], ...]:
    def shallow_copy(iterable: Iterable[Domain]) -> Tuple[Iterable[Domain], ...]:
        return tee(iterable, count)

    copies = shallow_copy(map(shallow_copy, nested_iterable))
    return tuple(map(itemgetter(index), iterables)
                 for index, iterables in enumerate(copies))

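Pros:

  • quite easy to read and explain.

Cons:

  • if we want to extend the approach to iterables with a greater nesting level (an iterable of nested iterables and so on), it does not look helpful.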
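
Both points can be checked with a short sketch. It repeats the copy_nested_iterable definition above in compact form, without annotations, so that it runs on its own. The sample inputs are illustrative assumptions.

```python
from itertools import tee
from operator import itemgetter


def copy_nested_iterable(nested_iterable, *, count=2):
    def shallow_copy(iterable):
        return tee(iterable, count)

    copies = shallow_copy(map(shallow_copy, nested_iterable))
    return tuple(map(itemgetter(index), iterables)
                 for index, iterables in enumerate(copies))


# Two levels of nesting: the copies are independent of each other.
first, second = copy_nested_iterable([iter(range(3)), iter(range(2))])
two_level_first = list(map(list, first))
two_level_second = list(map(list, second))
print(two_level_first)   # [[0, 1, 2], [0, 1]]
print(two_level_second)  # [[0, 1, 2], [0, 1]]

# Three levels of nesting: the innermost iterator is still shared,
# so the second copy finds it already exhausted.
first, second = copy_nested_iterable([[iter(range(3))]])
three_level_first = [[list(leaf) for leaf in branch] for branch in first]
three_level_second = [[list(leaf) for leaf in branch] for branch in second]
print(three_level_first)   # [[[0, 1, 2]]]
print(three_level_second)  # [[[]]]
```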

We can do better.

Improved solution

If we look at the itertools.tee function documentation, it contains a Python recipe which, with the help of the functools.singledispatch decorator, can be rewritten like

from collections import (abc,
                         deque)
from functools import singledispatch
from itertools import repeat
from typing import (Iterable,
                    Tuple,
                    TypeVar)

Domain = TypeVar('Domain')


@singledispatch
def copy(object_: Domain,
         *,
         count: int) -> Iterable[Domain]:
    raise TypeError('Unsupported object type: {type}.'
                    .format(type=type(object_)))

# handle general case
@copy.register(object)
# immutable strings represent a special kind of iterables
# that can be copied by simply repeating
@copy.register(bytes)
@copy.register(str)
# mappings cannot be copied as other iterables
# since they are iterable only by key
@copy.register(abc.Mapping)
def copy_object(object_: Domain,
                *,
                count: int) -> Iterable[Domain]:
    return repeat(object_, count)


@copy.register(abc.Iterable)
def copy_iterable(object_: Iterable[Domain],
                  *,
                  count: int = 2) -> Tuple[Iterable[Domain], ...]:
    iterator = iter(object_)
    # we are using `itertools.repeat` instead of `range` here
    # due to efficiency of the former
    # more info at
    # https://stackoverflow.com/questions/9059173/what-is-the-purpose-in-pythons-itertools-repeat/9098860#9098860
    queues = [deque() for _ in repeat(None, count)]

    def replica(queue: deque) -> Iterable[Domain]:
        while True:
            if not queue:
                try:
                    element = next(iterator)
                except StopIteration:
                    return
                element_copies = copy(element, count=count)
                for sub_queue, element_copy in zip(queues, element_copies):
                    sub_queue.append(element_copy)
            yield queue.popleft()

    return tuple(replica(queue) for queue in queues)
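
A usage sketch of the code above, repeated in compact form without annotations so the block runs on its own. It uses singledispatch and repeat as imported, since the listing refers to them through module names it never imports. The helper materialize and the sample input are hypothetical and serve only to make the lazy copies visible.

```python
from collections import abc, deque
from functools import singledispatch
from itertools import repeat


@singledispatch
def copy(object_, *, count):
    raise TypeError('Unsupported object type: {type}.'
                    .format(type=type(object_)))


@copy.register(object)
@copy.register(bytes)
@copy.register(str)
@copy.register(abc.Mapping)
def copy_object(object_, *, count):
    return repeat(object_, count)


@copy.register(abc.Iterable)
def copy_iterable(object_, *, count=2):
    iterator = iter(object_)
    queues = [deque() for _ in repeat(None, count)]

    def replica(queue):
        while True:
            if not queue:
                try:
                    element = next(iterator)
                except StopIteration:
                    return
                element_copies = copy(element, count=count)
                for sub_queue, element_copy in zip(queues, element_copies):
                    sub_queue.append(element_copy)
            yield queue.popleft()

    return tuple(replica(queue) for queue in queues)


def materialize(object_):
    """Hypothetical helper: turn lazy copies into nested lists."""
    if (isinstance(object_, (str, bytes, abc.Mapping))
            or not isinstance(object_, abc.Iterable)):
        return object_
    return [materialize(element) for element in object_]


# Mixed input: a nested iterable with an iterator inside, an int and a string.
first, second = copy_iterable([[iter(range(3))], 42, 'ab'])
first_elements = materialize(first)
second_elements = materialize(second)
print(first_elements)   # [[[0, 1, 2]], 42, 'ab']
print(second_elements)  # [[[0, 1, 2]], 42, 'ab']
```

Unlike the naive version, the innermost iterator is replicated too, because every element goes through copy again before being queued.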

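Pros:

  • handles nesting on deeper levels, and even mixed elements (both iterables and non-iterables) on the same level,
  • may be extended for user-defined structures (e.g. for making independent deep copies of them).

Cons:

  • less readable (but as we know, "practicality beats purity"),
  • adds some overhead related to dispatching (but that is fine, since it is based on a dictionary lookup, which has O(1) complexity).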
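
The extension point might look like the sketch below. It is a simplified stand-in, not the full code above: copy here only has a repeat-based fallback, and the Matrix class and copy_matrix handler are hypothetical names introduced for illustration.

```python
from copy import deepcopy
from functools import singledispatch
from itertools import repeat


@singledispatch
def copy(object_, *, count):
    # Simplified fallback: hand out the same object `count` times.
    return repeat(object_, count)


class Matrix:
    """Hypothetical user-defined mutable structure."""

    def __init__(self, rows):
        self.rows = rows


@copy.register(Matrix)
def copy_matrix(object_, *, count):
    # Independent deep copies instead of repeated references.
    return tuple(deepcopy(object_) for _ in range(count))


first, second = copy(Matrix([[1, 2]]), count=2)
first.rows[0][0] = 99
print(second.rows)  # [[1, 2]]
```

Registering a handler only adds an entry to the dispatch registry, so the existing handlers stay untouched.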

Test

Preparation

Let's define our nested iterable as follows

nested_iterable = [range(10 ** index) for index in range(1, 7)]

Since creating the iterators says nothing about the performance of the underlying copies, let's define a function that exhausts an iterable

exhaust_iterable = deque(maxlen=0).extend
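
A small sketch of how this helper behaves; the sample iterators are illustrative. A zero-length deque discards everything it is fed, so extend simply drives the iterable to its end. It consumes only the top level it is given, so nested iterators have to be flattened first if they should be drained as well.

```python
from collections import deque
from itertools import chain

exhaust_iterable = deque(maxlen=0).extend
flatten = chain.from_iterable

iterator = iter(range(5))
exhaust_iterable(iterator)
remaining = next(iterator, None)
print(remaining)  # None

# Only the outer list is walked here; the inner iterator stays untouched.
inner = iter(range(3))
exhaust_iterable([inner])
untouched = list(inner)
print(untouched)  # [0, 1, 2]

# Flattening first drains the inner iterator as well.
inner = iter(range(3))
exhaust_iterable(flatten([inner]))
drained = list(inner)
print(drained)  # []
```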

Time

Using timeit

import timeit

def naive(): exhaust_iterable(copy_nested_iterable(nested_iterable))

def improved(): exhaust_iterable(copy_iterable(nested_iterable))

print('naive approach:', min(timeit.repeat(naive)))
print('improved approach:', min(timeit.repeat(improved)))

On my laptop with Windows 10 x64 and Python 3.5.4 I get

naive approach: 5.1863865
improved approach: 3.5602296000000013

Memory

Using memory_profiler

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.6 MiB     51.4 MiB       result = list(flatten(flatten(copy_nested_iterable(nested_iterable))))

for the "naive" approach and

Line #    Mem usage    Increment   Line Contents
================================================
    78     17.2 MiB     17.2 MiB   @profile
    79                             def profile_memory(nested_iterable: Iterable[Iterable[Any]]) -> None:
    80     68.7 MiB     51.4 MiB       result = list(flatten(flatten(copy_iterable(nested_iterable))))

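for the "improved" one.

Note: I ran the two scripts separately, because running both in one process would not be representative: the second statement would reuse the int objects already created by the first.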


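Conclusion

As we can see, both functions have similar performance, but the latter supports deeper levels of nesting and looks quite extensible.

Advertisement

I've added the "improved" solution to the lz package as of version 0.4.0; it can be used like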

>>> from lz.replication import replicate
>>> iterable = iter(range(5))
>>> list(map(list, replicate(iterable,
                             count=3)))
[[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]

It is property-based tested using the hypothesis framework, so we may be sure that it works as expected.

