c++ - VNNI 指令的 NEON 仿真
问题描述
Cascade Lake Intel CPU 中有新的 AVX-512 VNNI指令,可以加速 CPU 上的神经网络推理。我将它们集成到Simd 库中以加速Synet(我的神经网络推理小框架)并获得显着的性能提升。
事实上,我只使用了一条指令_mm512_dpbusd_epi32
( vpdpbusd
),它允许执行 8 位有符号和无符号整数的乘法,然后将它们累加到 32 位整数累加器中。
为 NEON(ARM 平台)执行模拟优化会很棒。
所以有一个问题:
是否存在任何类似 NEON 指令的模拟vpdpbusd
?如果没有类似物,模拟指令的最佳方法是什么?
下面有一个标量实现(为了更好地理解函数必须做什么):
inline void pdpbusd(int32x4_t& sum, uint8x16_t input, int8x16_t weight)
{
for (size_t i = 0; i < 4; ++i)
for (size_t j = 0; j < 4; ++j)
sum[i] += int32_t(input[i * 4 + j]) * int32_t(weight[i * 4 + j]);
}
解决方案
最直接的实现需要 3 条指令;vmovl.s8
,vmovl.u8
将有符号和无符号 8 位值扩展为 16 位,然后vmlal.s16
, 进行有符号延长 16 位乘法运算,累加到 32 位寄存器中。由于vmlal.s16
仅处理 4 个元素,您需要一秒钟vmlal.s16
来乘以和累加以下 4 个元素 - 所以 4 个元素的 4 条指令。
对于 aarch64 语法,对应的指令是sxtl
,uxtl
和smlal
.
编辑:如果输出元素应该水平聚合,则不能使用融合乘法累加指令vmlal
。然后解决方案是vmovl.s8
and vmovl.u8
,然后是vmul.i16
(对于 8 个输入元素),vpaddl.s16
(水平聚合两个元素),然后是另一个vpadd.i32
以获得水平 4 个元素的总和。所以 5 条指令对应 8 个输入元素,或者 10 条指令对应一个完整的 128 位向量,然后是一个 finalvadd.s32
将最终结果累加到累加器。在 AArch64 上, , 的等价物vpadd.i32
可以addp
处理 128 位向量,因此那里少了一条指令。
如果您使用的是 instrinsics,则实现可能如下所示:
int32x4_t vpdpbusd(int32x4_t sum, uint8x16_t input, int8x16_t weight) {
int16x8_t i1 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input)));
int16x8_t i2 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input)));
int16x8_t w1 = vmovl_s8(vget_low_s8(weight));
int16x8_t w2 = vmovl_s8(vget_high_s8(weight));
int16x8_t p1 = vmulq_s16(i1, w1);
int16x8_t p2 = vmulq_s16(i2, w2);
int32x4_t s1 = vpaddlq_s16(p1);
int32x4_t s2 = vpaddlq_s16(p2);
#if defined(__aarch64__)
int32x4_t s3 = vpaddq_s32(s1, s2);
#else
int32x4_t s3 = vcombine_s32(
vpadd_s32(vget_low_s32(s1), vget_high_s32(s1)),
vpadd_s32(vget_low_s32(s2), vget_high_s32(s2))
);
#endif
sum = vaddq_s32(sum, s3);
return sum;
}
推荐阅读
- autohotkey - 使用鼠标位置激活键盘键的代码
- ios - Flutter 应用程序作为从 IOS 上的浏览器接收 pdf 的选项
- python - 如何用变量更新sql表
- amazon-s3 - Cloudwatch 突然上传到 s3 的警报
- ansible - 从 git repo (Ansible for Nsxt) 安装和使用 Ansible 模块
- json - Elasticsearch 查询以获取多个属性的值
- string - 如何在lua中删除多行字符串
- android - Mailchimp 和 Gmail:文本颜色失真
- python - windows上部分导入pyvirtualcam、python 3.8
- data-visualization - 水平箱线图Stata