python - python:计算向量到矩阵每一行的欧几里得距离的最快方法?
问题描述
考虑一下这个 python 代码,我在其中尝试计算向量到矩阵每一行的欧式距离。与我使用 Tullio.jl 找到的最佳 Julia 版本相比,它非常慢。
python 版本需要30s而 Julia 版本只需要75ms。
我确信我在 Python 方面做得不是最好的。有更快的解决方案吗?欢迎使用 Numba 和 numpy 解决方案。
import numpy as np
# generate
a = np.random.rand(4000000, 128)
b = np.random.rand(128)
print(a.shape)
print(b.shape)
def lin_norm_ever(a, b):
return np.apply_along_axis(lambda x: np.linalg.norm(x - b), 1, a)
import time
t = time.time()
res = lin_norm_ever(a, b)
print(res.shape)
elapsed = time.time() - t
print(elapsed)
朱莉娅版本
using Tullio
function comp_tullio(a, c)
dist = zeros(Float32, size(a, 2))
@tullio dist[i] = (c[j] - a[j,i])^2
dist
end
@time comp_tullio(a, c)
@benchmark comp_tullio(a, c) # 75ms on my computer
解决方案
我将在此示例中使用 Numba 以获得最佳性能。我还添加了来自 Divakars 链接答案的 2 种方法以进行比较。
代码
import numpy as np
import numba as nb
from scipy.spatial.distance import cdist
@nb.njit(fastmath=True,parallel=True,cache=True)
def dist_1(mat,vec):
res=np.empty(mat.shape[0],dtype=mat.dtype)
for i in nb.prange(mat.shape[0]):
acc=0
for j in range(mat.shape[1]):
acc+=(mat[i,j]-vec[j])**2
res[i]=np.sqrt(acc)
return res
#from https://stackoverflow.com/a/52364284/4045774
def dist_2(mat,vec):
return cdist(mat, np.atleast_2d(vec)).ravel()
#from https://stackoverflow.com/a/52364284/4045774
def dist_3(mat,vec):
M = mat.dot(vec)
d = np.einsum('ij,ij->i',mat,mat) + np.inner(vec,vec) -2*M
return np.sqrt(d)
计时
#Float64
a = np.random.rand(4000000, 128)
b = np.random.rand(128)
%timeit dist_1(a,b)
#122 ms ± 3.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dist_2(a,b)
#484 ms ± 3.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dist_3(a,b)
#432 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#Float32
a = np.random.rand(4000000, 128).astype(np.float32)
b = np.random.rand(128).astype(np.float32)
%timeit dist_1(a,b)
#68.6 ms ± 414 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dist_2(a,b)
#2.2 s ± 32.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#looks like there is a costly type-casting to float64
%timeit dist_3(a,b)
#228 ms ± 8.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
推荐阅读
- excel - 如何将 Line2 中的每个单元格与 Line1 进行比较
- c++ - 期货和线程,运算符 +
- mysql - MYSQL 中的 Alter Column 语句有什么问题?
- git - 如何从 master 的另一个功能分支 rebase 或合并分支
- javascript - 为什么在使用 jquery 的 Firefox 中放大和缩小功能无法正常工作?
- rspec - Capybara wit RSpec 无法找到未禁用的字段“Nombre del cliente”
- node.js - 尝试在猫鼬中使用 $near 时获取并清空数组
- bash - 如何计算Makefile中文件的行数?
- python - Python从HTML页面读取完整表格
- python - 具有稀疏矩阵的 sklearn.svm.SVR 会产生错误