首页 > 解决方案 > 如果每行的最大值小于特定阈值,如何修改 numpy 数组的最后一列?

问题描述

我在 Python 中使用 numpy,我对它很陌生。我有一个问题,我一直试图自己解决,尽管我无法弄清楚。

假设我有一个 numpy 数组,它保存从 0 到 1 的概率,并且有 4 列,如下所示:

print(probabilities) 
[[0.99859875 0.00126745 0.00011843 0.00001534 ]
 [0.9995454  0.00036321 0.00007128 0.00002014 ]
 [0.6         0.2       0.1        0.1        ]]

print(probabilities.shape) 
(3, 4)

我想做的是检查行的最大值是否小于特定阈值(在这种情况下假设为 0.7),如果是,则创建一个新列,其概率为 1 并使其余列具有该行的概率为 0。应用于上述示例的结果应该是:

 print(new_probabilities) 
    [[0.99859875 0.00126745 0.00011843 0.00001534  0.0]
     [0.9995454  0.00036321 0.00007128 0.00002014  0.0]
     [0.0        0.0        0.0        0.0         1.0]]
    
    print(probabilities.shape) 
    (3, 5)

有谁知道如何做到这一点?到目前为止,我刚刚创建了一个新的 numpy 数组new_probabilitiesprobabilities如下所示:

new_probabilities = np.zeros((probabilities.shape[0],5))
new_probabilities[:,:-1] = probabilities
print(new_probabilities.shape)
(3,5)

但我不确定如何从这里继续。我真的很感激你能提供的任何帮助!!

标签: pythonnumpy

解决方案


首先获取每行的最大值:

 maxima = np.max(probabilities, axis=1)

然后制作一个面具,告诉您哪些小于您的阈值:

threshold = 0.7
mask = maxima < threshold

现在只填写new_probabilities与掩码不匹配的行:

new_probabilities = np.empty((probabilities.shape[0], probabilities.shape[1] + 1), probabilities.dtype)
new_probabilities[~mask, :-1] = probabilities[~mask]

现在您可以填写最后一列作为掩码:

new_probabilities[:, -1] = mask

推荐阅读