首页 > 解决方案 > 沿第三维从 Numpy ndarray 中删除并提取最小值

问题描述

我有一个形状数组,(3, 4, 3)它代表 3x4 像素和 3 个颜色通道(最后一个索引)的图像。我的目标是留下一个(3, 4, 2)数组,其中每个像素都删除了最低颜色通道。我可以进行逐像素迭代,但这将非常耗时。使用np.argmin我可以轻松提取最小值的索引,以便我知道哪个颜色通道包含每个像素的最小值。但是,我找不到一种巧妙的索引方法来删除这些值,以便留下一个(3, 4, 2)数组。

此外,我还尝试通过使用类似array[:, :, indexMin)]但无法获得(3, 4)包含每个像素的最小通道值的所需形状数组来选择最小值。我知道有一个功能np.amin,但它会让我更好地理解 Numpy 数组。下面提供了我的代码结构的最小示例:

import numpy as np

arr = [[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 10, 11, 12]]
array = np.zeros((3, 4, 3))
array[:,:,0] = arr
array[:,:,1] = np.fliplr(arr)
array[:,:,2] = np.flipud(arr)
indexMin = np.argmin(array, axis=2)

标签: pythonarraysnumpy

解决方案


您需要正确广播到所需输出形状的数组。您可以使用以下方法添加缺少的维度np.expand_dims

index = np.expand_dims(np.argmin(array, axis=2), axis=2)

这使得设置或提取要删除的元素变得容易:

index = list(np.indices(array.shape, sparse=True))
index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = array[tuple(index)]

np.indiceswithsparse=True返回一组范围,以在每个维度中正确广播索引。一个更好的选择是使用np.take_along_axis

index = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = np.take_along_axis(array, index, axis=2)

您可以使用这些结果来创建掩码,例如np.put_along_axis

mask = np.ones(array.shape, dtype=bool)
np.put_along_axis(mask, index, 0, axis=2)

使用掩码索引数组可为您提供:

result = array[mask].reshape(*array.shape[:2], -1)

重塑有效,因为您的像素存储在最后一个维度中,该维度应该在内存中是连续的。这意味着掩码正确地删除了三个元素中的一个,因此在内存中正确排序。这在掩蔽操作中并不常见。

另一种选择是使用np.deleteraveled 数组和np.ravel_multi_index

i = np.indices(array.shape[:2], sparse=True)
index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
result = np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)

只是为了好玩,您可以利用每个像素只有三个元素的事实来创建要保留的元素的完整索引。这个想法是所有三个指数的总和是3。因此,3 - np.argmin(array, axis=2) - np.argmax(array, axis=2)是中位元。如果您堆叠中位数和最大值,您会得到一个类似于sort给您的索引:

amax = np.argmax(array, axis=2)
amin = np.argmin(array, axis=2)
index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
result = np.take_along_axis(array, index, axis=2)

调用 tonp.clip是处理所有元素都相等的情况所必需的,在这种情况下两者都argmax返回argmin零。

定时

比较方法:

def remove_min_indices(array):
    index = list(np.indices(array.shape, sparse=True))
    index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
    mask = np.ones(array.shape, dtype=bool)
    mask[tuple(index)] = False
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_put(array):
    mask = np.ones(array.shape, dtype=bool)
    np.put_along_axis(mask, np.expand_dims(np.argmin(array, axis=2), axis=2), 0, axis=2)
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_delete(array):
    i = np.indices(array.shape[:2], sparse=True)
    index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
    return np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)

def remove_min_sort_c(array):
    return np.sort(array, axis=2)[..., 1:]

def remove_min_sort_i(array):
    array.sort(axis=2)
    return array[..., 1:]

def remove_min_median(array):
    amax = np.argmax(array, axis=2)
    amin = np.argmin(array, axis=2)
    index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
    return np.take_along_axis(array, index, axis=2)

测试了类似array = np.random.randint(10, size=(N, N, 3), dtype=np.uint8)for Nin的数组{100, 1K, 10K, 100K, 1M}

  N  |   IND   |   PUT   |   DEL   |   S_C   |   S_I   |   MED   |
-----+---------+---------+---------+---------+---------+---------+
100  | 648. µs | 658. µs | 765. µs | 497. µs | 246. µs | 905. µs |
  1K | 67.9 ms | 68.1 ms | 85.7 ms | 51.7 ms | 24.0 ms | 123. ms |
 10K | 6.86 s  | 6.86 s  | 8.72 s  | 5.17 s  | 2.39 s  | 13.2 s  |
-----+---------+---------+---------+---------+---------+---------+

N^2正如预期的那样,时间比例为 。排序返回与其他方法不同的结果,但显然是最有效的。对于较大的数组,屏蔽put_along_axis似乎是更有效的方法,而对于较小的数组,原始索引似乎更有效。


推荐阅读