首页 > 解决方案 > 拆分 3D numpy 数组以获取具有满足条件的连续值的所有系列

问题描述

我想在 3D 数组中沿轴 0 获取满足条件的所有数字序列。如果可能的话,我试图对所有内容进行矢量化以避免非常低效的循环。

以下代码有效,它采用数组 a 定义,例如形状 (6, 2, 3)。它在其上应用一个掩码,由数组 b 定义(相同的维度)。

然后,我沿轴 0 对我的 3D 数组进行切片,这样我就有 2*3 个 1D 切片,从而获得 6 个形状为 (6,) 的 1D 数组。为了做到这一点,我使用了一个循环,这显然会成为更大数组的效率问题。

然后我根据获得的掩码拆分我的数组,并选择(现在简单打印)具有至少 3 个连续值满足 b 数组给定条件的系列。

import numpy as np

a = np.array([[[0.57337127, 0.7626088, 0.26965987],
               [0.66987041, 0.2914202, 0.62678441]],
              [[0.97442524, 0.61656519, 0.10544983],
               [0.05780219, 0.00381356, 0.57118615]],
              [[0.47069657, 0.36802822, 0.67483419],
               [0.32773146, 0.99773064, 0.56042508]],
              [[0.70984651, 0.25093198, 0.71911127],
               [0.05182876, 0.9463291, 0.7222756]],
              [[0.56736192, 0.62692889, 0.33814278],
               [0.72362855, 0.12885637, 0.44096788]],
              [[0.12706838, 0.90640269, 0.5126569],
               [0.62920448, 0.24502599, 0.26754067]]])

b = np.array([[[0.4, 0.4, 0.4],
               [0.4, 0.4, 0.4]],
              [[0.4, 0.4, 0.4],
               [0.4, 0.4, 0.4]],
              [[0.4, 0.4, 0.4],
               [0.4, 0.4, 0.4]],
              [[0.4, 0.4, 0.4],
               [0.4, 0.4, 0.4]],
              [[0.4, 0.4, 0.4],
               [0.4, 0.4, 0.4]],
              [[0.4, 0.4, 0.4],
               [0.4, 0.4, 0.4]]])

# these two loops i and j are very inefficient 
for i in range(a.shape[1]):
    for j in range(a.shape[2]):
        print(i, j)
        aij = a[:, i, j]
        bij = b[:, i, j]
        mask = aij <= bij
        split_indices = np.where(mask)[0]
        for subarray in np.split(aij, split_indices + 1):
            if len(subarray) > 3:
                print(subarray[:-1])

现在,这行得通。但是,我的实际数据(数组 a 和 b)的形状约为(500、800、1500),这意味着循环变得有问题(非常昂贵)。

你能想出一种更矢量化它的方法吗?我试图获得一个 3D 蒙版并以 3D 方式拆分,但这会导致沿轴 1 和 2 的拆分大小不相等,这是一个问题(以及 np.split 仅采用 0 或 1-D 索引列表的原因.

标签: pythonarrayspython-3.xnumpyvectorization

解决方案


(1,2)您可以遍历axis,而不是循环遍历axis ,0以迭代地检查掩码是否适用于连续元素。根据您最终想要实现的目标,最多需要n迭代(n您拥有的最长序列在哪里)。

例如,如果您只是想识别至少 3 个连续元素的任何有效序列的起始元素axis=0mask您可以这样做:

mask = a > b
runs = np.zeros_like(mask, dtype=bool)
runs[:-2] = mask[:-2] & mask[1:-1] & mask[2:]

在您的示例中,这会产生:

>>> runs
array([[[ True, False, False],
    [False, False,  True]],
   [[ True, False, False],
    [False, False,  True]],
   [[ True, False, False],
    [False, False,  True]],
   [[False, False, False],
    [False, False, False]],
   [[False, False, False],
    [False, False, False]],
   [[False, False, False],
    [False, False, False]]])

在这里,对至少长度为 3 的有效序列的所有起始元素runs求值True。由于我不确定一旦确定它们后你想做什么,我就留在这里。但希望很清楚如何从这里开始概括您正在尝试做的任何事情。


推荐阅读