首页 > 解决方案 > 在numpy中获取索引子集的有效方法

问题描述

我有以下索引,您可以从以下索引中获取它们np.where(...)

coords = (
  np.asarray([0 0 0 1 1 1 1 1 2 2 2 3 3 3 3 4 4 4 5 5 5 5 5 6 6 6]),
  np.asarray([2 2 8 2 2 4 4 6 2 2 6 2 2 4 6 2 2 6 2 2 4 4 6 2 2 6]),
  np.asarray([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]),
  np.asarray([0 1 0 0 1 0 1 1 0 1 1 0 1 1 1 0 1 1 0 1 0 1 1 0 1 1])
)

另一个带有索引的元组是为了选择那些在coords

index = tuple(
  np.asarray([0 0 1 1 1 1 2 2 2 3 3 3 3 4 4 4 5 5 5 5 5 6 6 6]),
  np.asarray([2 8 2 4 4 6 2 2 6 2 2 4 6 2 2 6 2 2 4 4 6 2 2 6]),
  np.asarray([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]),
  np.asarray([0 0 1 0 1 1 0 1 1 0 1 1 1 0 1 1 0 1 0 1 1 0 1 1])
)

例如,选择 coords[0] 是因为它在索引中(位置 0),但未coords[1]选择是因为它在index.

我可以使用以下方法轻松计算掩码[x in zip(*index) for x in zip(*coords)](从 bool 转换为 int 以获得更好的可读性):

[1 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

但这对于较大的数组不是很有效。有没有更“基于numpy”的方法可以计算掩码?

标签: pythonnumpy

解决方案


您可以使用np.ravel_multi_index列压缩为更易于处理的唯一数字:

cmx = *map(np.max, coords),
imx = *map(np.max, index),
shape = np.maximum(cmx, imx) + 1

ct = np.ravel_multi_index(coords, shape)
it = np.ravel_multi_index(index, shape)

it.sort()

result = ct == it[it.searchsorted(ct)]
print(result.view(np.int8))

印刷:

[1 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]

推荐阅读