python - 在 2d numpy 数组的每行中查找 N 个最高值
问题描述
我有一个 2d numpy 数组a
:
a = np.array(range(0,25)).reshape(5,5)
---
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]]
我想找到每行的最大 N 值并将它们替换为 100。我以缓慢的方式执行此操作:
N = 2
idx = a.argsort()
for i in range(a.shape[0]):
a[i,idx[i][::-1][0:N]] = 100
print(a)
---
[[ 0 1 2 100 100]
[ 5 6 7 100 100]
[ 10 11 12 100 100]
[ 15 16 17 100 100]
[ 20 21 22 100 100]]
实际上我的矩阵的形状是 6000*6000。如何以更好的方式做到这一点?喜欢申请?
解决方案
你可以argpartition
在这里使用:
N=2
ix = a.argpartition(-N)[:,-N:]
a[np.arange(a.shape[0])[:,None], ix] = 100
print(a)
array([[ 0, 1, 2, 100, 100],
[ 5, 6, 7, 100, 100],
[ 10, 11, 12, 100, 100],
[ 15, 16, 17, 100, 100],
[ 20, 21, 22, 100, 100]])
推荐阅读
- javascript - Promise 引发“未处理的承诺拒绝”错误的奇怪行为
- python - 多自变量的最佳建模技术
- symfony - 日志实体创建
- android - Google Place 自动完成
- css - 如何根据自定义元素属性设置样式:root
- android -
颤振医生时发表评论--android-licenses -v - ios - 如何在 3 秒后显示新的视图控制器
- python - TypeError: loadshortlink() 为参数“shortlink”获得了多个值
- python - Python3 Sqlite3 - 如何正确转义执行脚本?
- javascript - 如何使用 JavaScript 有条件地将过滤器应用于 Firestore 查询