首页 > 解决方案 > 如何在 numpy 数组上使用 argmax

问题描述

我在 Python 中使用 Keras 深度学习。我有一个具有这种形状的 numpy 数组: [5, 30, 30, 30, 65] 我想沿第 5 维(65)执行 argmax,将每个第 5 个数组中最大值的索引设置为 1 和所有其他为0。对于上下文,该数组由生成器模型输出,并且我使用长度为65的最后一个维度作为单热编码。我该怎么做呢?

标签: pythonmachine-learningnumpy-ndarray

解决方案


argmax在第 5 维上执行操作,您可以axis使用从 0 开始的计数指定参数。

因此,要获得argmax第 5 维,您将使用:
idx = np.argmax(arr, axis=4)[0, 0, 0, 0]

然后,要将第 5 维数组转换为 one-hot 编码,您可以将该数组设置为一个np.zeros数组,然后将idx(从上面)的值更新为 1,如下所示:

arr[a, b, c, d] = np.zeros((65,))
arr[a, b, c, d, idx] = 1

其中a, b, c, 和d是第 1 到第 4 维的索引。


推荐阅读