c++ - 我获取 int 数组点积的内在函数比普通代码慢,我做错了什么?
问题描述
我正在尝试了解内在以及如何正确利用和优化它,我决定实现一个函数来获取两个数组的点积作为学习的起点。
我创建了两个函数来获取整数数组的点积,int
一个以正常方式编码,您循环遍历两个数组的每个元素,然后与每个元素执行乘法,然后将结果乘积相加/累加/求和以获得点积。
另一个使用内在的方式,我对每个数组的四个元素执行内在操作,我使用它们中的每一个相乘_mm_mullo_epi32
,然后使用 2 个水平加法 _mm_hadd_epi32
来获得当前 4 个元素的总和,然后我将它加到dot_product,然后继续下一个四个元素,然后重复直到达到计算的 limit vec_loop
,然后我使用正常的方式计算其他剩余元素以避免计算出数组的内存,然后我比较两者的性能。
具有两种点积函数的头文件:
// main.hpp
#ifndef main_hpp
#define main_hpp
#include <iostream>
#include <immintrin.h>
template<typename T>
T scalar_dot(T* a, T* b, size_t len){
T dot_product = 0;
for(size_t i=0; i<len; ++i) dot_product += a[i]*b[i];
return dot_product;
}
int sse_int_dot(int* a, int* b, size_t len){
size_t vec_loop = len/4;
size_t non_vec = len%4;
size_t start_non_vec_i = len-non_vec;
int dot_prod = 0;
for(size_t i=0; i<vec_loop; ++i)
{
__m128i va = _mm_loadu_si128((__m128i*)(a+(i*4)));
__m128i vb = _mm_loadu_si128((__m128i*)(b+(i*4)));
va = _mm_mullo_epi32(va,vb);
va = _mm_hadd_epi32(va,va);
va = _mm_hadd_epi32(va,va);
dot_prod += _mm_cvtsi128_si32(va);
}
for(size_t i=start_non_vec_i; i<len; ++i) dot_prod += a[i]*b[i];
return dot_prod;
}
#endif
cpp 代码来测量每个函数所花费的时间
// main.cpp
#include <iostream>
#include <chrono>
#include <random>
#include "main.hpp"
int main()
{
// generate random integers
unsigned seed = std::chrono::steady_clock::now().time_since_epoch().count();
std::mt19937_64 rand_engine(seed);
std::mt19937_64 rand_engine2(seed/2);
std::uniform_int_distribution<int> random_number(0,9);
size_t LEN = 10000000;
int* a = new int[LEN];
int* b = new int[LEN];
for(size_t i=0; i<LEN; ++i)
{
a[i] = random_number(rand_engine);
b[i] = random_number(rand_engine2);
}
#ifdef SCALAR
int dot1 = 0;
#endif
#ifdef VECTOR
int dot2 = 0;
#endif
// timing
auto start = std::chrono::high_resolution_clock::now();
#ifdef SCALAR
dot1 = scalar_dot(a,b,LEN);
#endif
#ifdef VECTOR
dot2 = sse_int_dot(a,b,LEN);
#endif
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end-start);
std::cout<<"proccess taken "<<duration.count()<<" nanoseconds\n";
#ifdef SCALAR
std::cout<<"\nScalar : Dot product = "<<dot1<<"\n";
#endif
#ifdef VECTOR
std::cout<<"\nVector : Dot product = "<<dot2<<"\n";
#endif
return 0;
}
汇编:
- 内在版本:
g++ main.cpp -DVECTOR -msse4.1 -o main.o
- 普通版:
g++ main.cpp -DSCALAR -msse4.1 -o main.o
我的机器:
- 架构:x86_64
- CPU : 1
- CPU 内核:4
- 每个内核的线程数:1
- 型号名称:Intel(R) Pentium(R) CPU N3700 @ 1.60GHz
- L1d 缓存:96 KiB
- L1i 缓存:128 KiB
- 二级缓存:2 MiB
- 一些标志:sse、sse2、sse4_1、sse4_2
在main.cpp
有10000000个元素的int
数组中,当我在我的机器上编译上面的代码时,内在函数似乎比普通版本运行得慢,大多数时候,内在函数需要大约97529675 nanoseconds
,有时甚至更长,而普通代码只需要87568313 nanoseconds
,在这里我认为如果优化标志关闭,我的内在函数应该运行得更快,但事实证明它确实有点慢。
所以我的问题是:
- 为什么我的内在函数运行速度较慢?(难道我做错了什么?)
- 如何纠正我的内在实现,正确的方法是什么?
- 即使优化标志关闭,编译器是否会自动矢量化/展开正常代码
- 鉴于我的机器规格,获得点积的最快方法是什么?
我希望有人可以帮助,谢谢
解决方案
因此,根据@Peter Cordes、@Qubit 和@j6t的建议,我对代码进行了一些调整,现在我只在循环内进行乘法运算,然后将水平加法移到了循环外……它设法提高了内部版本从 around 97529675 nanoseconds
,到 around56444187 nanoseconds
比我以前的实现要快得多,具有相同的编译标志和10000000个 int 数组元素。
这是 main.hpp 中的新函数
int _sse_int_dot(int* a, int* b, size_t len){
size_t vec_loop = len/4;
size_t non_vec = len%4;
size_t start_non_vec_i = len-non_vec;
int dot_product;
__m128i vdot_product = _mm_set1_epi32(0);
for(size_t i=0; i<vec_loop; ++i)
{
__m128i va = _mm_loadu_si128((__m128i*)(a+(i*4)));
__m128i vb = _mm_loadu_si128((__m128i*)(b+(i*4)));
__m128i vc = _mm_mullo_epi32(va,vb);
vdot_product = _mm_add_epi32(vdot_product,vc);
}
vdot_product = _mm_hadd_epi32(vdot_product,vdot_product);
vdot_product = _mm_hadd_epi32(vdot_product,vdot_product);
dot_product = _mm_cvtsi128_si32(vdot_product);
for(size_t i=start_non_vec_i; i<len; ++i) dot_product += a[i]*b[i];
return dot_product;
}
如果此代码还有更多需要改进的地方,请指出,现在我将把它留在这里作为答案。
推荐阅读
- laravel - 如何在 Laravel nova 中的索引上显示降价字段?
- javascript - 为什么我没有通过 id 获得一件商品?
- visual-c++ - CString::Replace not working with values from registry
- html - 我的变量在我的 JS 函数 JSX 中没有递增
- javascript - 当我想要 dockerize 我的 nextjs 项目时出现问题
- javascript - VueJS路由器中的`path`和`fullPath`有什么区别?
- glibc - LFS 8.3 Glibc-2.28 使检查失败'stdlib/test-bz22786'
- oracle - 创建 Oracle 表时出错,无效标识符
- pydicom - Orthanc匿名dcm文件,有没有办法直接转换dcm文件而不先生成匿名文件
- ios - WKWebView 没有将内容滚动到映射的 HTML