numpy - 计算两个numpy数组的向量之间的距离
问题描述
我有两个尺寸为 S x F 的 numpy 数组R 和尺寸为 N x M x F的W。具体让我们分配以下值,,,,N = 5
M = 7
F = 3
S = 4
数组R包含S = 4
具有F = 3
特征的样本集合。每行代表一个样本,每一行代表一个特征。因此R[0]
是第一个样本,R[1]
第二个样本继续。每个R[i-th]
条目都包含F
元素,举例来说R[0] = np.array([1, 4, -2])
。
这是一个初始化所有这些值的小片段,考虑到 MWE
import numpy as np
# Size of Map (rows, columns)
N, M = 5, 7
# Number of features
F = 3
# Sample size
S = 4
np.random.seed(13)
R = np.random.randint(0, 10, size=(S, F))
W = np.random.randint(-4, 5, size=(N, M, F))
我们还可以看到numpy 数组 W的给定“深度线”,作为一个向量,也与数组R的每一行具有相同的维度(这很容易注意到两个数组的最后一个维度的大小)。有了它,我可以访问W[2, 3]
和获取np.array([ 2, 2, -1 ])
(这里的值只是示例)。
我创建了一个简单的函数来计算给定向量r到矩阵W的每个“深度线”的距离,并将W深度线的最近元素的位置返回给r
def nearest_vector_matrix_naive(r, W):
delta = np.zeros((N,M), dtype=int)
for i in range(N):
for j in range(M):
norm = 0
for k in range(F):
norm += (r[k] - W[i,j,k])**2
delta[i,j] = norm
norm = 0
win_idx = np.unravel_index(np.argmin(delta, axis=None), delta.shape)
return win_idx
当然这是一种非常幼稚的方法,我可以进一步优化下面的代码,获得巨大的性能提升。
def nearest_vector_matrix(r, W):
delta = np.sum((W[:,:] - r)**2, axis=2)
return np.unravel_index(np.argmin(delta, axis=None), delta.shape)
我可以简单地使用这个功能
nearest_idx = nearest_vector_matrix(R[0], W)
# Returns the nearest vector in W to R[0]
W[nearest_idx]
由于我有带有一堆样本的数组R ,因此我使用以下代码段来计算最接近样本数组的向量:
def nearest_samples_matrix(R, W):
DELTA = np.zeros((R.shape[0],2))
for idx, r in enumerate(R):
delta = np.sum((W[:,:] - r)**2, axis=2)
DELTA[idx] = np.unravel_index(np.argmin(delta, axis=None), delta.shape)
return DELTA
此函数返回一个包含S行(S是样本数)的二维索引的数组。那就是 DELTA 有(S, 2)
形状(总是)。
我想知道如何替换for
内部的循环(例如用于广播)nearest_samples_matrix
以进一步提高代码执行性能?
我不知道该怎么做。(除了我能够在第一种情况下做到这一点)
解决方案
最佳解决方案取决于数组的输入大小
对于低维问题 dim<20 或更小,kdtree 方法通常是要走的路。关于这个主题有很多答案,例如。我几周前写的一篇。
如果问题的维度太高,您可以切换到蛮力算法。以下两种算法都比您的优化方法快得多,但在较大的输入大小和低维问题上,比 kdtree 方法 O(log(n)) 而不是 O(n^2) 慢得多。
蛮力1
以下示例使用此处描述的算法。它在大维问题上非常快,因为大部分计算都是在高度优化的矩阵-矩阵乘法算法中完成的。缺点是高内存使用(所有距离都在一个函数调用中计算)和精度问题,因为更容易出错的计算方法。
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
def nearest_samples_matrix_2(R,W):
R_Temp=R
W_Temp=W.reshape(-1,W.shape[2])
dist=euclidean_distances(R_Temp, W_Temp)
ind_1,ind_2=np.unravel_index(np.argmin(dist,axis=1),shape=(W.shape[0],W.shape[1]))
return np.vstack((ind_1,ind_2)).T
蛮力2
这与您的幼稚方法非常相似,但使用 JIT-Compiler (Numba) 来获得良好的性能。临时数组不是必需的,精度应该很好(只要不发生溢出)。在更大的输入尺寸上还有进一步优化(循环平铺)的空间。
import numpy as np
import numba as nb
#parallelization is only beneficial on larger input data
@nb.njit(fastmath=True,parallel=True,cache=True)
def nearest_samples_matrix_3(r, W):
ind_i=0
ind_j=0
out=np.empty((r.shape[0],2),dtype=np.int64)
for x in nb.prange(r.shape[0]):
delta=0
for k in range(W.shape[2]):
delta += (r[x,k] - W[0,0,k])**2
for i in range(W.shape[0]):
for j in range(W.shape[1]):
norm = 0
for k in range(W.shape[2]):
norm += (r[x,k] - W[i,j,k])**2
if norm < delta:
delta=norm
ind_i=i
ind_j=j
out[x,0]=ind_i
out[x,1]=ind_j
return out
计时
#small Arrays
N, M = 100, 200
F = 30
S = 50
R = np.random.randint(0, 10, size=(S, F))
W = np.random.randint(-4, 5, size=(N, M, F))
#your function
%timeit nearest_samples_matrix(R,W)
#268 ms ± 2.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit nearest_samples_matrix_2(R,W)
#5.62 ms ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit nearest_samples_matrix_3(R,W)
#3.68 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#larger arrays
N, M = 1_000, 2_000
F = 50
S = 100
R = np.random.randint(0, 10, size=(S, F))
W = np.random.randint(-4, 5, size=(N, M, F))
#%timeit nearest_samples_matrix_1(R,W)
#too slow
%timeit nearest_samples_matrix_2(R,W)
#2.76 s ± 17.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit nearest_samples_matrix_3(R,W)
#1.42 s ± 402 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
推荐阅读
- sql - 如何使用 apex_200100.www_flow_t_varchar2 作为 int 或 char
- ruby - 如何将哈希转换为字符串哈希?
- flutter - 单击另一个文本字段时,颤振文本字段值消失
- kubernetes - 为什么 kubernetes 默认服务帐号可以完全访问 docker 桌面上的 API?
- html - 在 excel 中使用 VBA 从网页 HTML 上的特定 div 类中获取文本
- python - 如何在 Django ORM 的查询集更新上扩展数组或将值附加到数组字段?[PostgreSQL]
- django - 避免同一序列化程序的不同表示。Django 休息框架
- spring-boot - Kubernetes 需要多长时间才能从 Endpoints 中移除一个终止的 pod?
- react-native - 如何测试正在渲染的图像?- 反应原生 + 玩笑
- javascript - 警告:无法对未安装的组件执行 React 状态更新。这是一个无操作——在 React 组件中来回切换时