首页 > 解决方案 > 在不支持 SSE4 的 CPU 上计算双精度二元向量(2 × double)的 floor 和 ceil

问题描述

这可以通过 SSE4.1 内在函数完成:`_mm_floor_pd` 和 `_mm_ceil_pd` 分别编译为 `roundpd xmm, xmm, 1` 和 `roundpd xmm, xmm, 2`。

如果只能使用 SSE/SSE2/SSE3,计算它们的最佳方法是什么?

标签: c++ assembly sse simd intrinsics

解决方案


这是在 SSE4.1 之前的 CPU 上执行 floor/ceil 的代码。请注意,使用 `-ffast-math` 会破坏它!因为该选项允许编译器按结合律重排浮点运算,从而把 `(x + magic) - magic` 直接化简为 `x`,取整步骤就被优化掉了。

#include <cmath>
#include <emmintrin.h>
#include <cstdio> // for printf

// MSVC does not understand GNU `__attribute__((...))` syntax. Expand it to
// nothing there so the GCC-only optimize(...) annotations still compile.
#if defined(_MSC_VER)
#  define __attribute__(x)
#endif

/// Plain pair of doubles: the value type taken and returned by the
/// SSE2 _floor() / _ceil() helpers.
struct vec2d {
    double x, y;
};

// SSE2 stand-in for SSE4.1 _mm_blendv_pd: a bitwise select that takes the
// bits of `b` where `mask` is set and the bits of `a` where it is clear.
// Because it selects bit by bit (not by each lane's sign bit), `mask` must
// be all-ones or all-zeros per 64-bit lane, e.g. the output of an
// _mm_cmp*_pd intrinsic, for it to behave as a per-lane blend.
static __m128d mm_blendv_pd(const __m128d& a, const __m128d& b, const __m128d& mask) noexcept {
  const __m128d picked_from_b = _mm_and_pd(mask, b);
  const __m128d picked_from_a = _mm_andnot_pd(mask, a);  // (~mask) & a
  return _mm_or_pd(picked_from_b, picked_from_a);
}

/**
 * SSE2-only floor() of both components of `v`, for CPUs without the SSE4.1
 * `roundpd` instruction (`_mm_floor_pd`).
 *
 * Method: for |x| < 2^52, computing (x + t) - t with t = copysign(2^52, x)
 * rounds x to an integer. Before rounding, x + t has magnitude in
 * [2^52, 2^53), where adjacent doubles are exactly 1 apart. If the rounded
 * value is above x, subtracting 1 gives floor(x).
 *
 * A single positive constant such as 2^52 + 2^51 is NOT sufficient. For
 * |x| >= 2^51 the sum leaves that binade and is rounded to a multiple of 2
 * (or of 0.5), so results can be off by one or keep a fractional part.
 *
 * Lanes with |x| >= 2^52 (already integral), +/-inf and NaN are returned
 * unchanged. The sign bit of x is OR-ed into the result so that
 * floor(-0.0) == -0.0, as with std::floor.
 *
 * Assumes the default MXCSR round-to-nearest mode. -ffast-math
 * (-fassociative-math) would let the compiler fold (x + t) - t into x,
 * hence the optimize attribute.
 *
 * NOTE(review): identifiers starting with '_' are reserved in the global
 * namespace; the name is kept only for compatibility with existing callers.
 */
__attribute__((optimize("-fno-associative-math")))
vec2d _floor(vec2d v) noexcept {
  const __m128d src = _mm_setr_pd(v.x, v.y);  // lane 0 = x, lane 1 = y
  const __m128d sign_mask = _mm_set1_pd(-0.0);  // only the sign bit set
  const __m128d two52 = _mm_set1_pd(4503599627370496.0);  // 2^52
  const __m128d one = _mm_set1_pd(1.0);

  const __m128d sign = _mm_and_pd(src, sign_mask);
  const __m128d abs_src = _mm_andnot_pd(sign_mask, src);
  // All-ones where |x| >= 2^52 or x is NaN (unordered): pass those through.
  const __m128d keep = _mm_cmpnlt_pd(abs_src, two52);

  // t = copysign(2^52, x) keeps x + t inside the binade whose ulp is 1.
  const __m128d magic = _mm_or_pd(two52, sign);
  const __m128d rounded = _mm_sub_pd(_mm_add_pd(src, magic), magic);  //! -ffast-math will break this!

  // Round-to-nearest may have rounded up; step back by one in that case.
  const __m128d adjust = _mm_and_pd(_mm_cmplt_pd(src, rounded), one);
  // floor(x) always has the sign of x; restoring it turns +0.0 into -0.0.
  const __m128d floored = _mm_or_pd(_mm_sub_pd(rounded, adjust), sign);

  const __m128d res = _mm_or_pd(_mm_andnot_pd(keep, floored), _mm_and_pd(keep, src));
  return {_mm_cvtsd_f64(res), _mm_cvtsd_f64(_mm_unpackhi_pd(res, res))};
}

/**
 * SSE2-only ceil() of both components of `v`, for CPUs without the SSE4.1
 * `roundpd` instruction (`_mm_ceil_pd`).
 *
 * Method: for |x| < 2^52, computing (x + t) - t with t = copysign(2^52, x)
 * rounds x to an integer. Before rounding, x + t has magnitude in
 * [2^52, 2^53), where adjacent doubles are exactly 1 apart. If the rounded
 * value is below x, adding 1 gives ceil(x).
 *
 * A single positive constant such as 2^52 + 2^51 is NOT sufficient. For
 * |x| >= 2^51 the sum leaves that binade and is rounded to a multiple of 2
 * (or of 0.5), so results can be off by one or keep a fractional part.
 *
 * Lanes with |x| >= 2^52 (already integral), +/-inf and NaN are returned
 * unchanged. The sign bit of x is OR-ed into the result so that, e.g.,
 * ceil(-0.3) == -0.0, as with std::ceil.
 *
 * Assumes the default MXCSR round-to-nearest mode. -ffast-math
 * (-fassociative-math) would let the compiler fold (x + t) - t into x,
 * hence the optimize attribute.
 *
 * NOTE(review): identifiers starting with '_' are reserved in the global
 * namespace; the name is kept only for compatibility with existing callers.
 */
__attribute__((optimize("-fno-associative-math")))
vec2d _ceil(vec2d v) noexcept {
  const __m128d src = _mm_setr_pd(v.x, v.y);  // lane 0 = x, lane 1 = y
  const __m128d sign_mask = _mm_set1_pd(-0.0);  // only the sign bit set
  const __m128d two52 = _mm_set1_pd(4503599627370496.0);  // 2^52
  const __m128d one = _mm_set1_pd(1.0);

  const __m128d sign = _mm_and_pd(src, sign_mask);
  const __m128d abs_src = _mm_andnot_pd(sign_mask, src);
  // All-ones where |x| >= 2^52 or x is NaN (unordered): pass those through.
  const __m128d keep = _mm_cmpnlt_pd(abs_src, two52);

  // t = copysign(2^52, x) keeps x + t inside the binade whose ulp is 1.
  const __m128d magic = _mm_or_pd(two52, sign);
  const __m128d rounded = _mm_sub_pd(_mm_add_pd(src, magic), magic);  //! -ffast-math will break this!

  // Round-to-nearest may have rounded down; step up by one in that case.
  const __m128d adjust = _mm_and_pd(_mm_cmpgt_pd(src, rounded), one);
  // ceil(x) always has the sign of x; restoring it turns +0.0 into -0.0.
  const __m128d ceiled = _mm_or_pd(_mm_add_pd(rounded, adjust), sign);

  const __m128d res = _mm_or_pd(_mm_andnot_pd(keep, ceiled), _mm_and_pd(keep, src));
  return {_mm_cvtsd_f64(res), _mm_cvtsd_f64(_mm_unpackhi_pd(res, res))};
}


int main() {
    vec2d v{3.1,4.6};

    vec2d v2 = _floor(v);
    vec2d v3 = _ceil(v);

    printf(" v2: %f %f\n",v2.x,v2.y);
    printf(" v3: %f %f\n",v3.x,v3.y);

    return 0;
}

实时代码

它基于博客文章《C++ Compilers and FP Rounding on X86》,但那篇文章中的代码有错误。


推荐阅读