matrix - cublasSgemm 行主要乘法
问题描述
我正在尝试使用 cublasSgemm 将两个以行优先顺序存储的非方阵相乘。我知道这个函数有一个参数,您可以在其中指定是否要转置矩阵(CUBLAS_OP_T),但结果以列优先顺序存储,我也需要以行优先顺序。
此外,我拥有的代码无法将非方阵与参数 CUBLAS_OP_T 相乘。CUBLAS_OP_N 只能是正方形或非正方形。
此外,我知道可以选择以列顺序声明矩阵
define IDX2C(i,j,ld) (((j)*(ld))+(i))
但这不是一个选项,因为我必须使用的矩阵将在其他程序中设置。
我想互联网上有很多信息,但我无法找到我的问题的任何答案。
我的代码如下:
int m = 2;
int k = 3;
int n = 4;
int print = 1;
cudaError_t cudaStat; // cudaMalloc status
cublasStatus_t stat; // CUBLAS functions status
cublasHandle_t handle; // CUBLAS context
int i,j;
float *a, *b,*c;
//malloc for a,b,c...
// define a mxk matrix a row by row
int ind =11;
for(j=0;j<m*k;j++){
a[j]=(float)ind++;
}
// define a kxn matrix b column by column
ind =11;
for(j=0;j<k*n;j++){
b[j]=(float)ind++;
}
// DEVICE
float *d_a, *d_b, *d_c;
//cudaMalloc for d_a, d_b, d_c...
stat = cublasCreate(&handle); // initialize CUBLAS context
stat = cublasSetMatrix(m,k, sizeof(*a), a, m, d_a, m);
stat = cublasSetMatrix(k,n, sizeof(*b), b, k, d_b, k);
float al =1.0f;
float bet =0.0f;
stat=cublasSgemm(handle,CUBLAS_OP_T,CUBLAS_OP_T,m,n,k,&al,d_a,m,d_b,k,&bet,d_c,m);
stat = cublasGetMatrix (m,n, sizeof (*c) ,d_c ,m,c,m); // cp d_c - >c
if(print == 1) {
printf ("\nc after Sgemm :\n");
for(i=0;i<m*n;i++){
printf ("%f ",c[i]); // print c after Sgemm
}
}
cudaFree (d_a);
cudaFree (d_b);
cudaFree (d_c);
cublasDestroy (handle); // destroy CUBLAS context
free (a);
free (b);
free (c);
return EXIT_SUCCESS ;
}
输出是乘以 A * B 的转置,即:(A * B)T。
但我想要的是 C = A * B 以行为主的顺序。
我希望有一个人可以帮助我。
解决方案
正如您所说,cuBLAS 将矩阵解释为列优先排序,因此当您执行时 cublasSgemm(handle,CUBLAS_OP_T,CUBLAS_OP_T,m,n,k,&al,d_a,m,d_b,k,&bet,d_c,m)
,您正确地转置了每个输入(以行优先形式创建)以准备列优先解释。问题是 cuBLAS 还以列优先顺序转储结果。
我们将欺骗 cuBLAS 进行计算,它将以列优先顺序输出,因此看起来就像我们以行优先顺序巧妙地解释它时的样子。因此,我们不计算 AB = C,而是计算= 。幸运的是,我们已经通过按行优先顺序创建 A 和 B 的动作获得了,所以我们可以简单地绕过 CUBLAS_OP_N 的转置。因此将行更改为cublasSgemm(handle,CUBLAS_OP_N,CUBLAS_OP_N,n,m,k,&al,d_b,n,d_a,k,&bet,d_c,n)
.
您提供的示例代码应该计算
在运行更新后的调用后cublasSgemm
,我们得到:
c after Sgemm :
548.000000 584.000000 620.000000 656.000000 683.000000 728.000000 773.000000 818.000000
推荐阅读
- php - Docker 连接 nginx/php/mysql 的问题
- python - 输入 0 与层 lstm_16 不兼容:预期 ndim=3,发现 ndim=2?
- authentication - 启用对非公共 Google Cloud Function 的令牌访问
- r - 有条件连接的 R 函数吗?
- python - 有没有办法使用运行函数的结果从 vaderSentiment 包创建一个新的 DataFrame
- python - 如何停止 multiprocessing.Pool.map 异常
- django - 如何在 django post_save 上使用 update_or_create?
- android - Android Emulator 不支持人脸检测
- javascript - 在 Javascript 中遍历分页 API 结果
- c++ - 通过构造函数完成时未填充向量