首页 > 解决方案 > 如何组合 n 维切片以获得规范的 ndslice

问题描述

考虑x由两个连续切片形成的 ndarray 的一部分(我在示例中使用 numpy,但问题更笼统。我实际上在我的应用程序中使用 pytorch。):

import numpy as np
x = np.arange(4 * 10 * 12 * 7).reshape(4, 10, 12, 7)
first = (slice(None), 3, slice(3, 9))
second = (2, slice(1, 3), slice(5))
out = x[first][second]

我想要一种方法来获得一个规范x.shape的、组合的 ndslice 作为 和 的first函数second。例如

combined = compose(x.shape, first, second)
assert np.equal(x[first][second], x[combined]).all()
assert combined == (2, 3, slice(4, 6), slice(5))

我只对“简单”的 ndslice 感兴趣,包括:单个 int、单个 slice 对象或 int 和 slice 对象的任意组合的元组。

通过规范,我的意思是生成的组合切片应该唯一标识x. 例如,这是访问同一段的另一种方法,它应该导致相同的组合 ndslice:

other_first = (slice(2, 4), slice(None), slice(2, 7))
other_second = (0, 3, slice(2, -1), slice(5))
combined = compose(x.shape, other_first, other_second)
assert np.equal(x[other_first][other_second], x[combined]).all()
assert combined == (2, 3, slice(4, 6), slice(5))

因为切片支持“无”和负索引,我们需要形状x才能获得规范的 ndslice。

相关问题

请注意,我对与其他地方讨论x[first][second]的不同(通常)感兴趣。x[first, second]

标签: pythonpytorchslicenumpy-ndarray

解决方案


如果有人碰巧知道内置或更优雅的解决方案(即使特定于 numpy 或 pytorch),我仍然很感兴趣,但这是我提出的通用的本土解决方案:

def compose_single(lhs, rhs, length):
    out = range(length)[lhs][rhs]
    return out if isinstance(out, int) else slice(out.start, out.stop, out.step)

def compose(shape, first, second):
    def ensure_tuple(ndslice):
        return ndslice if isinstance(ndslice, tuple) else (ndslice,)

    first = ensure_tuple(first)
    second = ensure_tuple(second)

    out = list(first) + [slice(None)] * (len(shape) - len(first))
    remaining_dims = [i for i, s in enumerate(out) if isinstance(s, slice)]
    for i, rhs in zip(remaining_dims, second):
        out[i] = compose_single(out[i], rhs, length=shape[i])
    return tuple(out)

请注意,规范输出不会使用负数或 None 开始或结束。所以我更新了下面的测试用例:

shape = (4, 10, 12, 7)
first = (slice(None), 3, slice(3, 9))
second = (2, slice(1, 3), slice(5))
expected_combined = (2, 3, slice(4, 6, 1), slice(0, 5, 1))

assert compose(shape, first, second) == expected_combined

other_first = (slice(2, 4), slice(None), slice(2, 7))
other_second = (0, 3, slice(2, -1), slice(5))

assert compose(shape, other_first, other_second) == expected_combined

推荐阅读