c - 4 个 uint16_t 的快速模 12 算法打包在一个 uint64_t 中
问题描述
考虑以下联合:
union Uint16Vect {
uint16_t _comps[4];
uint64_t _all;
};
是否有一种快速算法来确定每个分量是否等于 1 模 12?
一个天真的代码序列是:
Uint16Vect F(const Uint16Vect a) {
Uint16Vect r;
for (int8_t k = 0; k < 4; k++) {
r._comps[k] = (a._comps[k] % 12 == 1) ? 1 : 0;
}
return r;
}
解决方案
编译器会将除以常数优化为乘以倒数或乘法逆。例如x/12
将优化为x*43691 >> 19
bool h(uint16_t x)
{
return x % 12 == 1;
}
h(unsigned short):
movzx eax, di
imul eax, eax, 43691 ; = 0xFFFF*8/12 + 1
shr eax, 19
lea eax, [rax+rax*2]
sal eax, 2
sub edi, eax
cmp di, 1
sete al
ret
因为 SSE/AVX 中有乘法指令,所以很容易向量化。此外,x = (x % 12 == 1) ? 1 : 0;
可以简化为x = (x % 12 == 1)
然后转换为x = (x - 1) % 12 == 0
避免从常量表中加载值 1 进行比较。您可以使用向量扩展,以便 gcc 自动为您生成代码
typedef uint16_t ymm32x2 __attribute__((vector_size(32)));
ymm32x2 mod12(ymm32x2 x)
{
return !!((x - 1) % 12);
}
下面是gcc 的输出
mod12(unsigned short __vector(16)):
vpcmpeqd ymm3, ymm3, ymm3 ; ymm3 = -1
vpaddw ymm0, ymm0, ymm3
vpmulhuw ymm1, ymm0, YMMWORD PTR .LC0[rip] ; multiply with 43691
vpsrlw ymm2, ymm1, 3
vpsllw ymm1, ymm2, 1
vpaddw ymm1, ymm1, ymm2
vpsllw ymm1, ymm1, 2
vpcmpeqw ymm0, ymm0, ymm1
vpandn ymm0, ymm0, ymm3
ret
Clang 和 ICC 不支持!!
向量类型,因此您需要更改为(x - 1) % 12 == 0
. 不幸的是,编译器似乎不支持 __attribute__((vector_size(8))
发出 MMX 指令。但是现在你无论如何都应该使用 SSE 或 AVX
正如您在上面的同一个 Godbolt 链接中看到的那样,输出x % 12 == 1
更短,但您需要一个包含 1 的表来比较,这可能更好也可能不好。编译器也可能无法完全优化为手写代码,因此您可以尝试使用内部函数手动对代码进行矢量化。检查哪一个在您的情况下工作得更快
更好的方法是((x * 43691) & 0x7ffff) < 43691
,或者x * 357913942 < 357913942
如nwellnhof 的回答中提到的,它也应该很容易矢量化
或者,对于像这样的小输入范围,您可以使用查找表。基础版需要一个65536元素的数组
#define S1(x) ((x) + 0) % 12 == 1, ((x) + 1) % 12 == 1, ((x) + 2) % 12 == 1, ((x) + 3) % 12 == 1, \
((x) + 4) % 12 == 1, ((x) + 4) % 12 == 1, ((x) + 6) % 12 == 1, ((x) + 7) % 12 == 1
#define S2(x) S1((x + 0)*8), S1((x + 1)*8), S1((x + 2)*8), S1((x + 3)*8), \
S1((x + 4)*8), S1((x + 4)*8), S1((x + 6)*8), S1((x + 7)*8)
#define S3(x) S2((x + 0)*8), S2((x + 1)*8), S2((x + 2)*8), S2((x + 3)*8), \
S2((x + 4)*8), S2((x + 4)*8), S2((x + 6)*8), S2((x + 7)*8)
#define S4(x) S3((x + 0)*8), S3((x + 1)*8), S3((x + 2)*8), S3((x + 3)*8), \
S3((x + 4)*8), S3((x + 4)*8), S3((x + 6)*8), S3((x + 7)*8)
bool mod12e1[65536] = {
S4(0U), S4(8U), S4(16U), S4(24U), S4(32U), S4(40U), S4(48U), S4(56U)
}
要使用只需替换x % 12 == 1
为mod12e1[x]
. 这当然可以矢量化
但由于结果只有 1 或 0,你也可以使用65536 位数组将大小减小到只有 8KB
您还可以通过 4 和 3 的整除性来检查 12 的整除性。被 4 整除显然是微不足道的。3的整除可以通过多种方式计算
一个是计算奇数之和与偶数之和之间的差异,如גלעד ברקן 的答案,并检查它是否可被 3 整除
或者您可以检查以 2 为底的数字之和是否为2k(例如以 4、16、64 为底...),看看它是否可以被 3 整除。
这是有效的,因为在
b
检查任何除数 n 的除数的基础上b - 1
,只需检查数字的总和是否可被 n 整除。这是它的一个实现void modulo12equals1(uint16_t d[], uint32_t size) { for (uint32_t i = 0; i < size; i++) { uint16_t x = d[i] - 1; bool divisibleBy4 = x % 4 == 0; x = (x >> 8) + (x & 0x00ff); // max 1FE x = (x >> 4) + (x & 0x000f); // max 2D bool divisibleBy3 = !!((01111111111111111111111ULL >> x) & 1); d[i] = divisibleBy3 && divisibleBy4; } }
Roland Illig被 3 整除的学分
由于自动矢量化的汇编输出太长,您可以在Godbolt 链接上查看
也可以看看
推荐阅读
- node.js - how to make for loop in FindById and then send response in mongoose?
- bash - 使用 wget 将 bash 函数转换为 powershell
- java - 如何使用 Spring WebFlux 实现自定义异常处理程序
- javascript - 需要从函数传递参数并将其传递到 JSON 有效负载值中
- javascript - 添加和删除人员的循环算法
- javascript - How to save an output in console as a variable in react?
- python - 在 docker 容器中运行 FastAPI
- python - timedelta 到欧洲中部时间下午 12 点
- android - 使用 socket.io Android 发送 jsonArray 时出现问题
- laravel - 如何为背包管理员 crud 功能编写适当的测试?