python - numpy 中是否有一种方法可以验证一个数组是否包含在另一个数组中?
问题描述
我想验证一个 numpy 数组是否是另一个数组中的连续序列。
例如
a = np.array([1,2,3,4,5,6,7])
b = np.array([3,4,5])
c = np.array([2,3,4,6])
预期的结果是:
is_sequence_of(b, a) # should return True
is_sequence_of(c, a) # should return False
我想知道是否有一个 numpy 方法可以做到这一点。
解决方案
方法#1
我们可以使用一个np.searchsorted
-
def isin_seq(a,b):
# Look for the presence of b in a, while keeping the sequence
sidx = a.argsort()
idx = np.searchsorted(a,b,sorter=sidx)
idx[idx==len(a)] = 0
ssidx = sidx[idx]
return (np.diff(ssidx)==1).all() & (a[ssidx]==b).all()
请注意,这假设输入数组没有重复项。
样品运行 -
In [42]: isin_seq(a,b) # search for the sequence b in a
Out[42]: True
In [43]: isin_seq(c,b) # search for the sequence b in c
Out[43]: False
方法#2
另一个与skimage.util.view_as_windows
-
from skimage.util import view_as_windows
def isin_seq_v2(a,b):
return (view_as_windows(a,len(b))==b).all(1).any()
方法#3
这也可以被认为是模板匹配问题,因此,对于 int 数字,我们可以使用 OpenCV 的内置函数template-matching
:(cv2.matchTemplate
灵感来自this post
),就像这样 -
import cv2
from cv2 import matchTemplate as cv2m
def isin_seq_v3(arr,seq):
S = cv2m(arr.astype('uint8'),seq.astype('uint8'),cv2.TM_SQDIFF)
return np.isclose(S,0).any()
方法#4
我们的方法可以受益于short-circuiting
基于的方法。所以,我们将使用一个 withnumba
来提高性能,就像这样 -
from numba import njit
@njit
def isin_seq_numba(a,b):
m = len(a)
n = len(b)
for i in range(m-n+1):
for j in range(n):
if a[i+j]!=b[j]:
break
if j==n-1:
return True
return False
推荐阅读
- javascript - 如何替换列表对象中的一个值并使用 angularjs 将相同的新值分配给相同的对象列表?
- jquery - 将 html jquery 联系表单转换为 asp.net
- php - 按 ID 查询帖子总是返回最新的帖子(WordPress)
- swift - 使用已经存在的覆盖函数 layerClass 向 UIView 类添加阴影
- sql-server - 限制 SQL Server 事务日志
- github - 在存储库中未检测到 GitHub Actions 工作流
- python - Python程序子类化一个PyQt5窗口不能在函数中设置window.title
- graph - 如何使用 Microsoft Graph 提供完整的邮箱访问权限?
- c# - 在文本文件中记录 Mongo 驱动程序查询性能
- java - ApplicationContext.getBeansOfType 改变返回的映射 - WebFluxTest SpringBoot