首页 > 解决方案 > 更快的代码来计算具有循环(周期性)边界条件的 numpy 数组中的点之间的距离

问题描述

我知道如何使用 scipy.spatial.distance.cdist 计算数组中点之间的欧几里得距离

类似于这个问题的答案: 计算矩阵中的一个点与所有其他点之间的距离

但是,我想在假设循环边界条件的情况下进行计算,例如,在这种情况下,点 [0,0] 与点 [0,n-1] 的距离为 1,而不是 n-1 的距离。(然后,我将为目标细胞阈值距离内的所有点制作一个蒙版,但这不是问题的核心)。

我能想到的唯一方法是重复计算 9 次,域索引在 x、y 和 x&y 方向上添加/减去 n,然后堆叠结果并在 9 个切片中找到最小值。为了说明需要 9 次重复,我整理了一个简单的示意图,其中只有 1 个 J 点,并用圆圈标记,它显示了一个示例,在这种情况下,由三角形标记的单元格在域中的最近邻居反映到左上角。

在此处输入图像描述

这是我使用 cdist 为此开发的代码:

import numpy as np
from scipy import spatial
    
n=5 # size of 2D box (n X n points)
np.random.seed(1) # to make reproducible
a=np.random.uniform(size=(n,n)) 
i=np.argwhere(a>-1)  # all points, for each loc we want distance to nearest J 
j=np.argwhere(a>0.85) # set of J locations to find distance to.

# this will be used in the KDtree soln 
global maxdist
maxdist=2.0

def dist_v1(i,j):
    dist=[]
    # 3x3 search required for periodic boundaries.
    for xoff in [-n,0,n]:
        for yoff in [-n,0,n]:
            jo=j.copy()
            jo[:,0]-=xoff
            jo[:,1]-=yoff
            dist.append(np.amin(spatial.distance.cdist(i,jo,metric='euclidean'),1)) 
    dist=np.amin(np.stack(dist),0).reshape([n,n])
    return(dist)

这有效,并产生例如:

print(dist_v1(i,j))


[[1.41421356 1.         1.41421356 1.41421356 1.        ]
 [2.23606798 2.         1.41421356 1.         1.41421356]
 [2.         2.         1.         0.         1.        ]
 [1.41421356 1.         1.41421356 1.         1.        ]
 [1.         0.         1.         1.         0.        ]]

零显然标记了 J 点,并且距离是正确的(这个编辑纠正了我之前不正确的尝试)。

请注意,如果您更改最后两行以堆叠原始距离,然后只使用一个最小值,如下所示:

def dist_v2(i,j):
    dist=[]
    # 3x3 search required for periodic boundaries.
    for xoff in [-n,0,n]:
        for yoff in [-n,0,n]:
            jo=j.copy()
            jo[:,0]-=xoff
            jo[:,1]-=yoff
            dist.append(spatial.distance.cdist(i,jo,metric='euclidean')) 
    dist=np.amin(np.dstack(dist),(1,2)).reshape([n,n])
    return(dist)

对于较小的 n (<10) 它更快,但对于较大的数组 (n>10)则要慢得多

...但是无论哪种方式,对于我的大型数组(N=500 和 J 点数在 70 左右)来说它都很,这个搜索占用了大约 99% 的计算时间,(而且使用循环也有点难看) - 有没有更好/更快的方法?

我想到的其他选择是:

  1. scipy.spatial.KDTree.query_ball_point

通过进一步搜索,我发现有一个函数 scipy.spatial.KDTree.query_ball_point直接计算我的 J 点半径内的坐标,但它似乎没有任何使用周期性边界的设施,所以我假设仍然需要以某种方式使用 3x3 循环、堆栈,然后像上面那样使用 amin,所以我不确定这是否会更快。

我使用这个函数编写了一个解决方案,而不用担心周期性边界条件(即这不能回答我的问题)

def dist_v3(n,j):
    x, y = np.mgrid[0:n, 0:n]
    points = np.c_[x.ravel(), y.ravel()]
    tree=spatial.KDTree(points)
    mask=np.zeros([n,n])
    for results in tree.query_ball_point((j), maxdist):
        mask[points[results][:,0],points[results][:,1]]=1
    return(mask)

也许我没有以最有效的方式使用它,但这已经和我的基于 cdist 的解决方案一样慢,即使没有周期性边界。在两个 cdist 解决方案中包括 mask 函数,即在这些函数中替换return(dist)with return(np.where(dist<=maxdist,1,0)),然后使用 timeit,我得到以下 n=100 的时间:

from timeit import timeit

print("cdist v1:",timeit(lambda: dist_v1(i,j), number=3)*100)
print("cdist v2:",timeit(lambda: dist_v2(i,j), number=3)*100)
print("KDtree:", timeit(lambda: dist_v3(n,j), number=3)*100)

cdist v1: 181.80927299981704
cdist v2: 554.8205785999016
KDtree: 605.119637199823
  1. 为 [0,0] 的设定距离内的点创建一个相对坐标数组,然后手动循环 J 点,用这个相对点列表设置掩码 - 这样做的好处是“相对距离”计算仅执行一次(我的 J 点每个时间步都会改变),但我怀疑循环会很慢。

  2. 为 2D 域中的每个点预先计算一组掩码,因此在模型集成的每个时间步中,我只需选择 J 点的掩码并应用。这将使用大量内存(与 n^4 成正比)并且可能仍然很慢,因为您需要循环 J 点以组合掩码。

标签: pythonperformancenumpyscipydistance

解决方案


我将从图像处理的角度展示另一种方法,您可能会感兴趣,无论它是否最快。为方便起见,我只为一个奇怪的n.

nxn与其考虑一组点,不如i取而代之的是nxn盒子。我们可以将其视为二值图像。让每个点j成为该图像中的一个正像素。这n=5看起来像:

二值图像

现在让我们考虑一下图像处理中的另一个概念:膨胀。对于任何输入像素,如果它的 中有一个正像素neighborhood,则输出像素将为 1。这个邻域由所谓的Structuring Element: 一个布尔内核定义,其中将显示要考虑的邻居。

以下是我为这个问题定义 SE 的方式:

Y, X = np.ogrid[-n:n+1, -n:n+1]
SQ = X*X+Y*Y

H = SQ == r

直观地说,H 是一个掩码,表示“从中心开始满足方程的所有点” x*x+y*y=r。也就是说,H 中的所有点都sqrt(r)与中心相距一定距离。另一个可视化,它会非常清楚:

H1 H2 H4 H5

这是一个不断扩大的像素圈。每个掩码中的每个白色像素都表示与中心像素的距离恰好为 的点sqrt(r)。您也许还可以看出,如果我们迭代地增加 的值r,我们实际上会稳定地覆盖特定位置周围的所有像素位置,最终覆盖整个图像。(某些 r 值没有给出响应,因为对于任何一对点都不存在这样的距离 sqrt(r)。我们跳过那些 r 值——比如 3。)

所以这就是主要算法的作用。

  • 我们将从 0 开始逐步增加值r到某个较高的上限。
  • 在每一步中,如果图像中的任何位置 (x,y) 对膨胀有响应,则意味着在距它正好 sqrt(r) 距离处有一个 j 点!
  • 我们可以多次找到匹配项;我们将只保留第一个匹配项并丢弃更多匹配项以获得积分。我们这样做直到所有像素(x,y)都找到了它们的最小距离/第一次匹配。

所以你可以说这个算法取决于 nxn 图像中唯一距离对的数量。

这也意味着如果你在 j 中的点越来越多,算法实际上会变得更快,这有悖常理!

这种膨胀算法的最坏情况是当您拥有最少数量的点(j 中的一个点)时,因为那时它需要将 r 迭代到一个非常高的值才能从远处的点获得匹配。

在执行方面:

n=5 # size of 2D box (n X n points)
np.random.seed(1) # to make reproducible
a=np.random.uniform(size=(n,n)) 
I=np.argwhere(a>-1)  # all points, for each loc we want distance to nearest J 
J=np.argwhere(a>0.85)

Y, X = np.ogrid[-n:n+1, -n:n+1]
SQ = X*X+Y*Y

point_space = np.zeros((n, n))
point_space[J[:,0], J[:,1]] = 1


C1 = point_space[:, :n//2]
C2 = point_space[:, n//2+1:]
C = np.hstack([C2, point_space, C1])

D1 = point_space[:n//2, :]
D2 = point_space[n//2+1:, :]
D2_ = np.hstack([point_space[n//2+1:, n//2+1:],D2,point_space[n//2+1:, :n//2]])
D1_ = np.hstack([point_space[:n//2:, n//2+1:],D1,point_space[:n//2, :n//2]])
D = np.vstack([D2_, C, D1_])
p = (3*n-len(D))//2
D = np.pad(D, (p,p), constant_values=(0,0))

plt.imshow(D, cmap='gray')
plt.title(f'n={n}')

D1 D2

如果您查看 n=5 的图像,您可以知道我做了什么;我只是用它的四个象限填充图像以表示循环空间,然后添加一些额外的零填充以解决最坏情况的搜索边界。

@nb.jit
def dilation(image, output, kernel, N, i0, i1):
    for i in range(i0,i1):
        for j in range(i0, i1):
            a_0 = i-(N//2)
            a_1 = a_0+N
            b_0 = j-(N//2)
            b_1 = b_0+N
            neighborhood = image[a_0:a_1, b_0:b_1]*kernel
            if np.any(neighborhood):
                output[i-i0,j-i0] = 1
    return output

@nb.njit(cache=True)
def progressive_dilation(point_space, out, total, dist, SQ, n, N_):
    for i in range(N_):
        if not np.any(total): 
            break
            
        H = SQ == i

        rows, cols = np.nonzero(H)
        if len(rows) == 0: continue
        
        rmin, rmax = rows.min(), rows.max()
        cmin, cmax = cols.min(), cols.max()

        H_ = H[rmin:rmax+1, cmin:cmax+1]
        
        out[:] = False 
        out = dilation(point_space, out, H_, len(H_), n, 2*n)
        
        idx = np.logical_and(out, total)
        
        for a, b in  zip(*np.where(idx)):
            dist[a, b] = i
        
        total = total * np.logical_not(out)
    return dist

def dilateWrap(D, SQ, n):
    out = np.zeros((n,n), dtype=bool)
    total = np.ones((n,n), dtype=bool)
    dist=-1*np.ones((n,n))
    dist = progressive_dilation(D, out, total, dist, SQ, n, 2*n*n+1)
    return dist

dout = dilateWrap(D, SQ, n)

如果我们可视化dout,我们实际上可以获得距离的惊人视觉表示。

O1 氧气

黑点基本上是存在j个点的位置。最亮的点自然是指离任何 j 最远的点。请注意,我将值保留为平方形式以获得整数图像。实际距离仍然是一平方根。结果与球场算法的输出相匹配。

# after resetting n = 501 and rerunning the first block

N = J.copy()
NE = J.copy()
E = J.copy()
SE = J.copy()
S = J.copy()
SW = J.copy()
W = J.copy()
NW = J.copy()

N[:,1] = N[:,1] - n
NE[:,0] = NE[:,0] - n
NE[:,1] = NE[:,1] - n
E[:,0] = E[:,0] - n
SE[:,0] = SE[:,0] - n
SE[:,1] = SE[:,1] + n
S[:,1] = S[:,1] + n
SW[:,0] = SW[:,0] + n
SW[:,1] = SW[:,1] + n
W[:,0] = W[:,0] + n
NW[:,0] = NW[:,0] + n
NW[:,1] = NW[:,1] - n

def distBP(I,J):
    tree = BallTree(np.concatenate([J,N,E,S,W,NE,SE,SW,NW]), leaf_size=15, metric='euclidean')
    dist = tree.query(I, k=1, return_distance=True)
    minimum_distance = dist[0].reshape(n,n)
    return minimum_distance

print(np.array_equal(distBP(I,J), np.sqrt(dilateWrap(D, SQ, n))))

出去:

True

现在进行时间检查,n=501。

from timeit import timeit
nl=1
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: dilateWrap(D, SQ, n),number=nl))

出去:

ball tree: 1.1706031339999754
dilation: 1.086665302000256

我会说它们大致相等,尽管膨胀的边缘很小。事实上,膨胀仍然缺少平方根运算,让我们补充一下。

from timeit import timeit
nl=1
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: np.sqrt(dilateWrap(D, SQ, n)),number=nl))

出去:

ball tree: 1.1712950239998463
dilation: 1.092416919000243

平方根对时间的影响基本上可以忽略不计。

现在,我之前说过,当 j 中实际上有更多点时,膨胀会变得更快。所以让我们增加j中的点数。

n=501 # size of 2D box (n X n points)
np.random.seed(1) # to make reproducible
a=np.random.uniform(size=(n,n)) 
I=np.argwhere(a>-1)  # all points, for each loc we want distance to nearest J 
J=np.argwhere(a>0.4) # previously a>0.85

现在查看时间:

from timeit import timeit
nl=1
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: np.sqrt(dilateWrap(D, SQ, n)),number=nl))

出去:

ball tree: 3.3354218500007846
dilation: 0.2178608220001479

球树实际上变慢了,而膨胀变快了!这是因为如果有很多 j 个点,我们可以通过几次膨胀重复快速找到所有距离。我发现这种效果相当有趣——通常你会认为运行时间会随着点数的增加而变差,但这里的情况正好相反。

相反,如果我们减少 j,我们会看到膨胀变慢:

#Setting a>0.9
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: np.sqrt(dilateWrap(D, SQ, n)),number=nl))

出去:

ball tree: 1.010353464000218
dilation: 1.4776274510004441

我认为我们可以有把握地得出结论,卷积或基于内核的方法将在这个特定问题上提供更好的收益,而不是基于对或点或基于树的方法。

最后,我在开头提到了它,我会再次提到它:整个实现只考虑 n 的奇数值;我没有耐心为偶数 n 计算适当的填充。(如果您熟悉图像处理,您可能以前曾遇到过这种情况:奇数尺寸更容易。)

这也可能会进一步优化,因为我只是偶尔涉足 numba。


推荐阅读