首页 > 解决方案 > 如何按行获取 numpy 数组的非零元素的数量?

问题描述

我想找到所有条目都小于 1e-6 或非零值的数量小于 3 的行的索引。这样的事情会很好:

import numpy as np

prob = np.random.rand(15, 500)
all_zero = np.where(prob.max(1) < 1e-6 | np.nonzero(prob, axis=1) < 3) 

标签: pythonnumpy

解决方案


我试图衡量迄今为止提出的解决方案的执行时间:
基准数据:

prob = np.random.rand(10000, 500)

@Massifox的解决方案与列表:

%%timeit
[i for i, val in enumerate(prob>1e-6)if val.sum() < 3]
# 39.5 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

@Massifox的解决方案只有 numpy:

%%timeit
np.where(np.sum(prob>1e-6, axis=1) < 3)
# 9.92 ms ± 199 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@a_guest的解决方案:

%%timeit
all_zero = np.logical_or(prob.max(axis=1) < 1e-6, np.sum(prob != 0, axis=1) < 3)
np.where(all_zero)
# 13.9 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

最有效的解决方案似乎是第二种。


推荐阅读