swift - 加速框架:计算向量的成对距离
问题描述
我有 N 个浮点向量,我想计算它们之间的成对归一化 L2 距离。对于向量 u 和 v,归一化 L2 距离定义为:|| u / ||u||_2 - v / ||v||_2 ||_2, where || ... ||_2 is the L2 norm (i.e. square root of the sum of squares)
我写了一个代表成对距离矩阵的类:
class PairwiseDistanceMatrix {
private let count: Int
private let buffer: [Float]
private let bufferPointer: UnsafeBufferPointer<Float>
let rows: Int
let columns: Int
init(with vectors: [[Float]]) {
count = vectors.count
let len = vDSP_Length(vectors[0].count)
var norm: Float = .nan
var divRes = [Float](repeating: .nan, count: Int(len))
// Normalizing the vectors:
let norms = Array(0..<count).map { (index) -> [Float] in
vDSP_svesq(vectors[index], 1, &norm, len)
norm = norm.squareRoot()
vDSP_vsdiv(vectors[index], 1, &norm, &divRes, 1, len)
return divRes
}
var mutableBuffer = [Float](repeating: 0.0, count: count * count)
let mutableBufferPointer = UnsafeMutableBufferPointer<Float>.init(start: &mutableBuffer, count: mutableBuffer.count)
// Computing the distances between the normalized vectors
var distancesq: Float = .nan
for i in 0..<(count - 1) {
for j in i..<count {
let index = i * count + j
let symetricIndex = j * count + i
vDSP_distancesq(norms[i], 1, norms[j], 1, &distancesq, len)
mutableBufferPointer[index] = distancesq.squareRoot()
mutableBufferPointer[symetricIndex] = mutableBufferPointer[index]
}
}
buffer = mutableBuffer
bufferPointer = UnsafeBufferPointer<Float>.init(mutableBufferPointer)
rows = count
columns = count
}
func elementAt(row: Int, column: Int) -> Float {
return bufferPointer[row * count + column]
}
}
目前,对于 5000 个向量,此代码在 iPhone X 上运行时间约为 600 毫秒。正如预期的那样,其中大部分采用嵌套循环。(向量归一化需要不到 2 毫秒)。我很确定这段代码可以优化。欢迎任何想法或方向。
解决方案
让我们忽略标准化,因为这不是问题,而且您已经有了一个相当有效的解决方案。
假设向量已经归一化,那么,我们要填充第 ij 项为 ||v_i - v_j|| 的矩阵。这可以比您当前的蛮力方法更有效地完成。请注意
||v_i - v_j||^2 = (v_i - v_j)*(v_i - v_j)
= v_i * v_i - 2 v_i * v_j + v_j * v_j
= ||v_i||^2 + ||v_j||^2 - 2 v_i * v_j
= 2(1 - v_i * v_j)
所以我们需要计算成对点积的矩阵,然后将 f(x) = sqrt(2 - 2*x) 应用于矩阵的每个条目。
如果我们将向量打包成矩阵 A 的行并计算 A*transpose(A),那么成对点积的矩阵正是我们得到的,这可以使用SSYRK
BLAS 提供的函数来完成。
推荐阅读
- visual-studio - VS创建的Nuget包不更新版本
- java - 在 Docker 容器(OS-Mac M1)中安装 Keyclock 时出现错误
- php - 必需的属性在codeigniter php中不起作用
- flutter - 颤振错误:(动态)=> Null 不是(字符串,动态)的子类型
- java - gradle 项目(由 IntelliJ IDEA 2021 导入)无法从 webapp 文件夹中识别 jar
- bash - 使用 bash 删除 CSF“请勿删除”IP 条目
- javascript - 从 iOS 应用程序调用函数时,Firebase HTTP 云函数给出 401 错误
- arangodb - ArangoDB 为同一文档返回不同的修订版本
- http - HTTP 标头值允许的字符
- javascript - 获取 TypeError:无法在 Laravel 中使用 Vue 读取未定义的属性