首页 > 解决方案 > 尽可能快地比较 (a + sqrt(b)) 形式的两个值?

问题描述

作为我正在编写的程序的一部分,我需要以a + sqrt(b)whereab是无符号整数的形式比较两个值。由于这是一个紧密循环的一部分,我希望这个比较尽可能快地运行。(如果重要的话,我在 x86-64 机器上运行代码,无符号整数不大于 10^6。另外,我知道一个事实a1<a2。)

作为一个独立的功能,这是我想要优化的。我的数字是足够小的整数double(甚至float)可以准确地表示它们,但结果中的舍入误差sqrt不能改变结果。

// known pre-condition: a1 < a2  in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);  // computed mathematically exactly
}

测试用例is_smaller(900000, 1000000, 900001, 998002)应该返回 true,但正如 @wim 计算它的评论中所示,它sqrtf()会返回 false。所以会(int)sqrt()截断回整数。

a1+sqrt(b1) = 90100a2+sqrt(b2) = 901000.00050050037512481206。最接近的浮点数正好是 90100。


由于sqrt()即使在现代 x86-64 上,当完全内联为指令时,该函数通常也相当昂贵sqrtsd,因此我试图尽可能避免调用sqrt()

通过平方去除 sqrt 还可以通过使所有计算精确来避免任何舍入错误的危险。

相反,如果函数是这样的......

bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}

...那我可以做return x-a1>=0 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;

但是现在因为有两个sqrt(...)项,我不能做同样的代数操作。

通过使用以下公式,我可以将值平方两次

      a1 + sqrt(b1) = a2 + sqrt(b2)
<==>  a1 - a2 = sqrt(b2) - sqrt(b1)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==>  (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2

无符号除以 4 很便宜,因为它只是一个位移,但由于我将数字平方两次,我将需要使用 128 位整数,并且我需要引入一些>=0检查(因为我正在比较不等式而不是等式)。

感觉可能有一种方法可以更快地做到这一点,通过对这个问题应用更好的代数。有没有办法更快地做到这一点?

标签: c++optimizationalgebramicro-optimizationsqrt

解决方案


这是一个没有 的版本sqrt,尽管我不确定它是否比只有一个的版本更快sqrt(它可能取决于值的分布)。

这是数学(如何删除两个 sqrts):

ad = a2-a1
bd = b2-b1

a1+sqrt(b1) < a2+sqrt(b2)              // subtract a1
   sqrt(b1) < ad+sqrt(b2)              // square it
        b1  < ad^2+2*ad*sqrt(b2)+b2    // arrange
   ad^2+bd  > -2*ad*sqrt(b2)

在这里,右边总是负数。如果左边是正数,那么我们必须返回 true。

如果左边是负数,那么我们可以平方不等式:

ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2

这里要注意的关键是 if a2>=a1+1000, thenis_smaller总是返回true(因为 的最大值sqrt(b1)是 1000)。如果a2<=a1+1000, thenad是一个小数,所以ad^4总是适合 64 位(不需要 128 位算术)。这是代码:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    if (ad>1000) {
        return true;
    }

    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;

    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

编辑:正如 Peter Cordes 所注意到的,第一个if不是必需的,因为第二个 if 处理它,所以代码变得更小更快:

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

推荐阅读