首页 > 解决方案 > Numpy max 仅在对数组的第一个元素上

问题描述

我有一个由元组组成的多维 numpy 数组,如下所示:

[[(0.56, 1),(0.25, 4), ...],[(0.11, 9), ...], ...]

每个元组的第二个元素是索引引用。我想提取每行第一个值最高的元组。有没有办法用 numpy max 实现这一点?

我尝试过的一件事是使用如下的轴参数:

np.max(my_array, axis=0)

但这会在未保留索引引用的对周围进行洗牌。例如,上面示例中的第一行将显示类似的[(0.56,4), ...]内容,而我希望它显示[(0.56,1), ...]

标签: pythonarraysnumpy

解决方案


不要在 numpy 数组中使用元组。将其全部转换为最后一维为 2 的 numpy 数组:

>>> a = np.array([[(0.56, 1), (0.25, 4)],[(0.11, 9), (0.19, 5)]])
>>> a.shape
(2, 2, 2)

然后:

>>> highest_val_per_row = np.argmax(a[:,:,0], axis=1)
>>> a[np.arange(a.shape[0]), highest_val_per_row]
array([[0.56, 1.  ],
       [0.19, 5.  ]])

推荐阅读