python - 更快的代码来计算具有循环(周期性)边界条件的 numpy 数组中的点之间的距离
问题描述
我知道如何使用 scipy.spatial.distance.cdist 计算数组中点之间的欧几里得距离
类似于这个问题的答案: 计算矩阵中的一个点与所有其他点之间的距离
但是,我想在假设循环边界条件的情况下进行计算,例如,在这种情况下,点 [0,0] 与点 [0,n-1] 的距离为 1,而不是 n-1 的距离。(然后,我将为目标细胞阈值距离内的所有点制作一个蒙版,但这不是问题的核心)。
我能想到的唯一方法是重复计算 9 次,域索引在 x、y 和 x&y 方向上添加/减去 n,然后堆叠结果并在 9 个切片中找到最小值。为了说明需要 9 次重复,我整理了一个简单的示意图,其中只有 1 个 J 点,并用圆圈标记,它显示了一个示例,在这种情况下,由三角形标记的单元格在域中的最近邻居反映到左上角。
这是我使用 cdist 为此开发的代码:
import numpy as np
from scipy import spatial
n=5 # size of 2D box (n X n points)
np.random.seed(1) # to make reproducible
a=np.random.uniform(size=(n,n))
i=np.argwhere(a>-1) # all points, for each loc we want distance to nearest J
j=np.argwhere(a>0.85) # set of J locations to find distance to.
# this will be used in the KDtree soln
global maxdist
maxdist=2.0
def dist_v1(i,j):
dist=[]
# 3x3 search required for periodic boundaries.
for xoff in [-n,0,n]:
for yoff in [-n,0,n]:
jo=j.copy()
jo[:,0]-=xoff
jo[:,1]-=yoff
dist.append(np.amin(spatial.distance.cdist(i,jo,metric='euclidean'),1))
dist=np.amin(np.stack(dist),0).reshape([n,n])
return(dist)
这有效,并产生例如:
print(dist_v1(i,j))
[[1.41421356 1. 1.41421356 1.41421356 1. ]
[2.23606798 2. 1.41421356 1. 1.41421356]
[2. 2. 1. 0. 1. ]
[1.41421356 1. 1.41421356 1. 1. ]
[1. 0. 1. 1. 0. ]]
零显然标记了 J 点,并且距离是正确的(这个编辑纠正了我之前不正确的尝试)。
请注意,如果您更改最后两行以堆叠原始距离,然后只使用一个最小值,如下所示:
def dist_v2(i,j):
dist=[]
# 3x3 search required for periodic boundaries.
for xoff in [-n,0,n]:
for yoff in [-n,0,n]:
jo=j.copy()
jo[:,0]-=xoff
jo[:,1]-=yoff
dist.append(spatial.distance.cdist(i,jo,metric='euclidean'))
dist=np.amin(np.dstack(dist),(1,2)).reshape([n,n])
return(dist)
对于较小的 n (<10) 它更快,但对于较大的数组 (n>10)则要慢得多
...但是无论哪种方式,对于我的大型数组(N=500 和 J 点数在 70 左右)来说它都很慢,这个搜索占用了大约 99% 的计算时间,(而且使用循环也有点难看) - 有没有更好/更快的方法?
我想到的其他选择是:
- scipy.spatial.KDTree.query_ball_point
通过进一步搜索,我发现有一个函数 scipy.spatial.KDTree.query_ball_point直接计算我的 J 点半径内的坐标,但它似乎没有任何使用周期性边界的设施,所以我假设仍然需要以某种方式使用 3x3 循环、堆栈,然后像上面那样使用 amin,所以我不确定这是否会更快。
我使用这个函数编写了一个解决方案,而不用担心周期性边界条件(即这不能回答我的问题)
def dist_v3(n,j):
x, y = np.mgrid[0:n, 0:n]
points = np.c_[x.ravel(), y.ravel()]
tree=spatial.KDTree(points)
mask=np.zeros([n,n])
for results in tree.query_ball_point((j), maxdist):
mask[points[results][:,0],points[results][:,1]]=1
return(mask)
也许我没有以最有效的方式使用它,但这已经和我的基于 cdist 的解决方案一样慢,即使没有周期性边界。在两个 cdist 解决方案中包括 mask 函数,即在这些函数中替换return(dist)
with return(np.where(dist<=maxdist,1,0))
,然后使用 timeit,我得到以下 n=100 的时间:
from timeit import timeit
print("cdist v1:",timeit(lambda: dist_v1(i,j), number=3)*100)
print("cdist v2:",timeit(lambda: dist_v2(i,j), number=3)*100)
print("KDtree:", timeit(lambda: dist_v3(n,j), number=3)*100)
cdist v1: 181.80927299981704
cdist v2: 554.8205785999016
KDtree: 605.119637199823
为 [0,0] 的设定距离内的点创建一个相对坐标数组,然后手动循环 J 点,用这个相对点列表设置掩码 - 这样做的好处是“相对距离”计算仅执行一次(我的 J 点每个时间步都会改变),但我怀疑循环会很慢。
为 2D 域中的每个点预先计算一组掩码,因此在模型集成的每个时间步中,我只需选择 J 点的掩码并应用。这将使用大量内存(与 n^4 成正比)并且可能仍然很慢,因为您需要循环 J 点以组合掩码。
解决方案
我将从图像处理的角度展示另一种方法,您可能会感兴趣,无论它是否最快。为方便起见,我只为一个奇怪的n
.
nxn
与其考虑一组点,不如i
取而代之的是nxn
盒子。我们可以将其视为二值图像。让每个点j
成为该图像中的一个正像素。这n=5
看起来像:
现在让我们考虑一下图像处理中的另一个概念:膨胀。对于任何输入像素,如果它的 中有一个正像素neighborhood
,则输出像素将为 1。这个邻域由所谓的Structuring Element
: 一个布尔内核定义,其中将显示要考虑的邻居。
以下是我为这个问题定义 SE 的方式:
Y, X = np.ogrid[-n:n+1, -n:n+1]
SQ = X*X+Y*Y
H = SQ == r
直观地说,H 是一个掩码,表示“从中心开始满足方程的所有点” x*x+y*y=r
。也就是说,H 中的所有点都sqrt(r)
与中心相距一定距离。另一个可视化,它会非常清楚:
这是一个不断扩大的像素圈。每个掩码中的每个白色像素都表示与中心像素的距离恰好为 的点sqrt(r)
。您也许还可以看出,如果我们迭代地增加 的值r
,我们实际上会稳定地覆盖特定位置周围的所有像素位置,最终覆盖整个图像。(某些 r 值没有给出响应,因为对于任何一对点都不存在这样的距离 sqrt(r)。我们跳过那些 r 值——比如 3。)
所以这就是主要算法的作用。
- 我们将从 0 开始逐步增加值
r
到某个较高的上限。 - 在每一步中,如果图像中的任何位置 (x,y) 对膨胀有响应,则意味着在距它正好 sqrt(r) 距离处有一个 j 点!
- 我们可以多次找到匹配项;我们将只保留第一个匹配项并丢弃更多匹配项以获得积分。我们这样做直到所有像素(x,y)都找到了它们的最小距离/第一次匹配。
所以你可以说这个算法取决于 nxn 图像中唯一距离对的数量。
这也意味着如果你在 j 中的点越来越多,算法实际上会变得更快,这有悖常理!
这种膨胀算法的最坏情况是当您拥有最少数量的点(j 中的一个点)时,因为那时它需要将 r 迭代到一个非常高的值才能从远处的点获得匹配。
在执行方面:
n=5 # size of 2D box (n X n points)
np.random.seed(1) # to make reproducible
a=np.random.uniform(size=(n,n))
I=np.argwhere(a>-1) # all points, for each loc we want distance to nearest J
J=np.argwhere(a>0.85)
Y, X = np.ogrid[-n:n+1, -n:n+1]
SQ = X*X+Y*Y
point_space = np.zeros((n, n))
point_space[J[:,0], J[:,1]] = 1
C1 = point_space[:, :n//2]
C2 = point_space[:, n//2+1:]
C = np.hstack([C2, point_space, C1])
D1 = point_space[:n//2, :]
D2 = point_space[n//2+1:, :]
D2_ = np.hstack([point_space[n//2+1:, n//2+1:],D2,point_space[n//2+1:, :n//2]])
D1_ = np.hstack([point_space[:n//2:, n//2+1:],D1,point_space[:n//2, :n//2]])
D = np.vstack([D2_, C, D1_])
p = (3*n-len(D))//2
D = np.pad(D, (p,p), constant_values=(0,0))
plt.imshow(D, cmap='gray')
plt.title(f'n={n}')
如果您查看 n=5 的图像,您可以知道我做了什么;我只是用它的四个象限填充图像以表示循环空间,然后添加一些额外的零填充以解决最坏情况的搜索边界。
@nb.jit
def dilation(image, output, kernel, N, i0, i1):
for i in range(i0,i1):
for j in range(i0, i1):
a_0 = i-(N//2)
a_1 = a_0+N
b_0 = j-(N//2)
b_1 = b_0+N
neighborhood = image[a_0:a_1, b_0:b_1]*kernel
if np.any(neighborhood):
output[i-i0,j-i0] = 1
return output
@nb.njit(cache=True)
def progressive_dilation(point_space, out, total, dist, SQ, n, N_):
for i in range(N_):
if not np.any(total):
break
H = SQ == i
rows, cols = np.nonzero(H)
if len(rows) == 0: continue
rmin, rmax = rows.min(), rows.max()
cmin, cmax = cols.min(), cols.max()
H_ = H[rmin:rmax+1, cmin:cmax+1]
out[:] = False
out = dilation(point_space, out, H_, len(H_), n, 2*n)
idx = np.logical_and(out, total)
for a, b in zip(*np.where(idx)):
dist[a, b] = i
total = total * np.logical_not(out)
return dist
def dilateWrap(D, SQ, n):
out = np.zeros((n,n), dtype=bool)
total = np.ones((n,n), dtype=bool)
dist=-1*np.ones((n,n))
dist = progressive_dilation(D, out, total, dist, SQ, n, 2*n*n+1)
return dist
dout = dilateWrap(D, SQ, n)
如果我们可视化dout,我们实际上可以获得距离的惊人视觉表示。
黑点基本上是存在j个点的位置。最亮的点自然是指离任何 j 最远的点。请注意,我将值保留为平方形式以获得整数图像。实际距离仍然是一平方根。结果与球场算法的输出相匹配。
# after resetting n = 501 and rerunning the first block
N = J.copy()
NE = J.copy()
E = J.copy()
SE = J.copy()
S = J.copy()
SW = J.copy()
W = J.copy()
NW = J.copy()
N[:,1] = N[:,1] - n
NE[:,0] = NE[:,0] - n
NE[:,1] = NE[:,1] - n
E[:,0] = E[:,0] - n
SE[:,0] = SE[:,0] - n
SE[:,1] = SE[:,1] + n
S[:,1] = S[:,1] + n
SW[:,0] = SW[:,0] + n
SW[:,1] = SW[:,1] + n
W[:,0] = W[:,0] + n
NW[:,0] = NW[:,0] + n
NW[:,1] = NW[:,1] - n
def distBP(I,J):
tree = BallTree(np.concatenate([J,N,E,S,W,NE,SE,SW,NW]), leaf_size=15, metric='euclidean')
dist = tree.query(I, k=1, return_distance=True)
minimum_distance = dist[0].reshape(n,n)
return minimum_distance
print(np.array_equal(distBP(I,J), np.sqrt(dilateWrap(D, SQ, n))))
出去:
True
现在进行时间检查,n=501。
from timeit import timeit
nl=1
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: dilateWrap(D, SQ, n),number=nl))
出去:
ball tree: 1.1706031339999754
dilation: 1.086665302000256
我会说它们大致相等,尽管膨胀的边缘很小。事实上,膨胀仍然缺少平方根运算,让我们补充一下。
from timeit import timeit
nl=1
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: np.sqrt(dilateWrap(D, SQ, n)),number=nl))
出去:
ball tree: 1.1712950239998463
dilation: 1.092416919000243
平方根对时间的影响基本上可以忽略不计。
现在,我之前说过,当 j 中实际上有更多点时,膨胀会变得更快。所以让我们增加j中的点数。
n=501 # size of 2D box (n X n points)
np.random.seed(1) # to make reproducible
a=np.random.uniform(size=(n,n))
I=np.argwhere(a>-1) # all points, for each loc we want distance to nearest J
J=np.argwhere(a>0.4) # previously a>0.85
现在查看时间:
from timeit import timeit
nl=1
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: np.sqrt(dilateWrap(D, SQ, n)),number=nl))
出去:
ball tree: 3.3354218500007846
dilation: 0.2178608220001479
球树实际上变慢了,而膨胀变快了!这是因为如果有很多 j 个点,我们可以通过几次膨胀重复快速找到所有距离。我发现这种效果相当有趣——通常你会认为运行时间会随着点数的增加而变差,但这里的情况正好相反。
相反,如果我们减少 j,我们会看到膨胀变慢:
#Setting a>0.9
print("ball tree:",timeit(lambda: distBP(I,J),number=nl))
print("dilation:",timeit(lambda: np.sqrt(dilateWrap(D, SQ, n)),number=nl))
出去:
ball tree: 1.010353464000218
dilation: 1.4776274510004441
我认为我们可以有把握地得出结论,卷积或基于内核的方法将在这个特定问题上提供更好的收益,而不是基于对或点或基于树的方法。
最后,我在开头提到了它,我会再次提到它:整个实现只考虑 n 的奇数值;我没有耐心为偶数 n 计算适当的填充。(如果您熟悉图像处理,您可能以前曾遇到过这种情况:奇数尺寸更容易。)
这也可能会进一步优化,因为我只是偶尔涉足 numba。
推荐阅读
- eclipse - 如何使用 e(fx)clipse 将 JavaFX-Project-Wizzard 添加到 Eclipse 4.9?
- swift - NSTableView 列之间的选项卡并移动到下一行
- r - 如何在 r 中绘制支出与年份的关系
- r - Hotelling T2 的乘法矩阵
- apache-spark - Spark SQL NOT 运算符和 Null 感知谓词子查询不能在嵌套条件中使用
- export - Sqoop 导出时出现错误,如何找出确切的错误?
- angular - 找不到“{Module Name}”的 NgModule 元数据中的错误
- java - 服务器无法解析ajax post参数中的某些内容,这是什么原因造成的?
- javascript - 为什么打字稿会为两个对象类型数组引发错误,但不会在两者的形状上都出现错误?
- java - 自定义 ArrayAdapter 上的 getFilter() 不起作用