首页 > 解决方案 > 根据字典中的存在过滤numpy数组

问题描述

我有一个 numpy ndarray 如下:

import numpy as np
x = np.array([[1, 2, 1], [4, 5, 7], [3, 2, 3]])

我有一本字典,其中保留了一些类 ID,如下所示:

k = {1: None, 2: None, 3: None}

现在,该 numpy 数组的最后一列包含类 ID。所以我想做的是根据字典中是否存在类ID来过滤numpy数组。因此,过滤该输入数组将给出第 1 行和第 3 行,因为7它不在字典中。

所以我得到类列:

cls = x[:, -1]

现在,我不知道如何使用它来过滤x数组而不循环遍历它并创建另一个数组。

标签: pythonnumpy

解决方案


这是一种方法numpy.in1d

keys = list(k.keys())
res = x[np.in1d(x[:, -1], keys)]

print(res)

[[1 2 1]
 [3 2 3]]

推荐阅读