python - 如何组合 n 维切片以获得规范的 ndslice
问题描述
考虑x
由两个连续切片形成的 ndarray 的一部分(我在示例中使用 numpy,但问题更笼统。我实际上在我的应用程序中使用 pytorch。):
import numpy as np
x = np.arange(4 * 10 * 12 * 7).reshape(4, 10, 12, 7)
first = (slice(None), 3, slice(3, 9))
second = (2, slice(1, 3), slice(5))
out = x[first][second]
我想要一种方法来获得一个规范x.shape
的、组合的 ndslice 作为 和 的first
函数second
。例如
combined = compose(x.shape, first, second)
assert np.equal(x[first][second], x[combined]).all()
assert combined == (2, 3, slice(4, 6), slice(5))
我只对“简单”的 ndslice 感兴趣,包括:单个 int、单个 slice 对象或 int 和 slice 对象的任意组合的元组。
通过规范,我的意思是生成的组合切片应该唯一标识x
. 例如,这是访问同一段的另一种方法,它应该导致相同的组合 ndslice:
other_first = (slice(2, 4), slice(None), slice(2, 7))
other_second = (0, 3, slice(2, -1), slice(5))
combined = compose(x.shape, other_first, other_second)
assert np.equal(x[other_first][other_second], x[combined]).all()
assert combined == (2, 3, slice(4, 6), slice(5))
因为切片支持“无”和负索引,我们需要形状x
才能获得规范的 ndslice。
相关问题
请注意,我对与其他地方讨论x[first][second]
的不同(通常)感兴趣。x[first, second]
解决方案
如果有人碰巧知道内置或更优雅的解决方案(即使特定于 numpy 或 pytorch),我仍然很感兴趣,但这是我提出的通用的本土解决方案:
def compose_single(lhs, rhs, length):
out = range(length)[lhs][rhs]
return out if isinstance(out, int) else slice(out.start, out.stop, out.step)
def compose(shape, first, second):
def ensure_tuple(ndslice):
return ndslice if isinstance(ndslice, tuple) else (ndslice,)
first = ensure_tuple(first)
second = ensure_tuple(second)
out = list(first) + [slice(None)] * (len(shape) - len(first))
remaining_dims = [i for i, s in enumerate(out) if isinstance(s, slice)]
for i, rhs in zip(remaining_dims, second):
out[i] = compose_single(out[i], rhs, length=shape[i])
return tuple(out)
请注意,规范输出不会使用负数或 None 开始或结束。所以我更新了下面的测试用例:
shape = (4, 10, 12, 7)
first = (slice(None), 3, slice(3, 9))
second = (2, slice(1, 3), slice(5))
expected_combined = (2, 3, slice(4, 6, 1), slice(0, 5, 1))
assert compose(shape, first, second) == expected_combined
other_first = (slice(2, 4), slice(None), slice(2, 7))
other_second = (0, 3, slice(2, -1), slice(5))
assert compose(shape, other_first, other_second) == expected_combined
推荐阅读
- core-data - 我可以在 Picker() 中使用 @FetchRequest 和 @State 变量吗?
- bash - 在 pyxtermjs 的一个端口中使用两个不同的 bash 终端
- css - 如何在不切断第一个元素的情况下使水平滚动的 div 居中?
- django - 如何在 wagtail admin 中更改一些单词翻译?
- python - Python - 从文件系统快速读取
- javascript - 如何将此导入更改为要求?
- sql - 将两个 SQL 查询组合成一个 SQL 查询语句
- php - 无法将数据传递给视图
- android - 如何获取 Firebase 身份验证创建日期
- angular - Firestore 错误 - firebase:emulator:start - Firestore Emulator 已退出,代码为:1