首页 > 解决方案 > 4 个 uint16_t 的快速模 12 算法打包在一个 uint64_t 中

问题描述

考虑以下联合:

union Uint16Vect {
    uint16_t _comps[4];
    uint64_t _all;
};

是否有一种快速算法来确定每个分量是否等于 1 模 12?

一个天真的代码序列是:

Uint16Vect F(const Uint16Vect a) {
    Uint16Vect r;
    for (int8_t k = 0; k < 4; k++) {
        r._comps[k] = (a._comps[k] % 12 == 1) ? 1 : 0;
    }
    return r;
}

标签: calgorithmvectorizationmoduloavx2

解决方案


编译器会将除以常数优化为乘以倒数或乘法逆。例如x/12将优化为x*43691 >> 19

bool h(uint16_t x)
{
    return x % 12 == 1;
}
h(unsigned short):
        movzx   eax, di
        imul    eax, eax, 43691 ; = 0xFFFF*8/12 + 1
        shr     eax, 19
        lea     eax, [rax+rax*2]
        sal     eax, 2
        sub     edi, eax
        cmp     di, 1
        sete    al
        ret

因为 SSE/AVX 中有乘法指令,所以很容易向量化。此外,x = (x % 12 == 1) ? 1 : 0;可以简化为x = (x % 12 == 1)然后转换为x = (x - 1) % 12 == 0避免从常量表中加载值 1 进行比较。您可以使用向量扩展,以便 gcc 自动为您生成代码

typedef uint16_t ymm32x2 __attribute__((vector_size(32)));
ymm32x2 mod12(ymm32x2 x)
{
    return !!((x - 1) % 12);
}

下面是gcc 的输出

mod12(unsigned short __vector(16)):
        vpcmpeqd    ymm3, ymm3, ymm3  ; ymm3 = -1
        vpaddw      ymm0, ymm0, ymm3
        vpmulhuw    ymm1, ymm0, YMMWORD PTR .LC0[rip] ; multiply with 43691
        vpsrlw      ymm2, ymm1, 3
        vpsllw      ymm1, ymm2, 1
        vpaddw      ymm1, ymm1, ymm2
        vpsllw      ymm1, ymm1, 2
        vpcmpeqw    ymm0, ymm0, ymm1
        vpandn      ymm0, ymm0, ymm3
        ret

Clang 和 ICC 不支持!!向量类型,因此您需要更改为(x - 1) % 12 == 0. 不幸的是,编译器似乎不支持 __attribute__((vector_size(8))发出 MMX 指令。但是现在你无论如何都应该使用 SSE 或 AVX

正如您在上面的同一个 Godbolt 链接中看到的那样,输出x % 12 == 1更短,但您需要一个包含 1 的表来比较,这可能更好也可能不好。编译器也可能无法完全优化为手写代码,因此您可以尝试使用内部函数手动对代码进行矢量化。检查哪一个在您的情况下工作得更快

更好的方法是((x * 43691) & 0x7ffff) < 43691,或者x * 357913942 < 357913942nwellnhof 的回答中提到的,它也应该很容易矢量化


或者,对于像这样的小输入范围,您可以使用查找表。基础版需要一个65536元素的数组

#define S1(x) ((x) + 0) % 12 == 1, ((x) + 1) % 12 == 1, ((x) + 2) % 12 == 1, ((x) + 3) % 12 == 1, \
              ((x) + 4) % 12 == 1, ((x) + 4) % 12 == 1, ((x) + 6) % 12 == 1, ((x) + 7) % 12 == 1
#define S2(x) S1((x + 0)*8), S1((x + 1)*8), S1((x + 2)*8), S1((x + 3)*8), \
              S1((x + 4)*8), S1((x + 4)*8), S1((x + 6)*8), S1((x + 7)*8)
#define S3(x) S2((x + 0)*8), S2((x + 1)*8), S2((x + 2)*8), S2((x + 3)*8), \
              S2((x + 4)*8), S2((x + 4)*8), S2((x + 6)*8), S2((x + 7)*8)
#define S4(x) S3((x + 0)*8), S3((x + 1)*8), S3((x + 2)*8), S3((x + 3)*8), \
              S3((x + 4)*8), S3((x + 4)*8), S3((x + 6)*8), S3((x + 7)*8)

bool mod12e1[65536] = {
    S4(0U), S4(8U), S4(16U), S4(24U), S4(32U), S4(40U), S4(48U), S4(56U)
}

要使用只需替换x % 12 == 1mod12e1[x]. 这当然可以矢量化

但由于结果只有 1 或 0,你也可以使用65536 位数组将大小减小到只有 8KB


您还可以通过 4 和 3 的整除性来检查 12 的整除性。被 4 整除显然是微不足道的。3的整除可以通过多种方式计算

  • 一个是计算奇数之和与偶数之和之间的差异,גלעד ברקן 的答案,并检查它是否可被 3 整除

  • 或者您可以检查以 2 为底的数字之和是否为2k(例如以 4、16、64 为底...),看看它是否可以被 3 整除。

    这是有效的,因为在b检查任何除数 n 的除数的基础上b - 1,只需检查数字的总和是否可被 n 整除。这是它的一个实现

      void modulo12equals1(uint16_t d[], uint32_t size) {
          for (uint32_t i = 0; i < size; i++)
          {
              uint16_t x = d[i] - 1;
              bool divisibleBy4 = x % 4 == 0;
              x = (x >> 8) + (x & 0x00ff); // max 1FE
              x = (x >> 4) + (x & 0x000f); // max 2D
              bool divisibleBy3 = !!((01111111111111111111111ULL >> x) & 1);
              d[i] = divisibleBy3 && divisibleBy4;
          }
      }
    

Roland Illig被 3 整除的学分

由于自动矢量化的汇编输出太长,您可以在Godbolt 链接上查看

也可以看看


推荐阅读