cuda - cublas batched gemm throw not supported error with large batch size
问题描述
我正在调用 cublasGemmStriderdBatchedEx() API。我让第一个矩阵大步前进,第二个矩阵固定。该程序适用于小输入,但在大批量时会引发 CUBLAS_STATUS_NOT_SUPPORTED 错误。
根据cublas 文档,这意味着不支持数据类型或算法。我看不出增加批量大小如何改变数据类型。我使用默认的启发式 GEMM 算法。
我用 CUDA9.2 编译了代码并在 GTX 1050 卡上运行它。
代码:
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "nvidia_helper/checkCudaErrors.h"
int UPPER_BOUND = 4096;
int main() {
half* F4_re;
half* X_split;
float* result1;
int M = 16;
int B = 256*64;
checkCudaErrors(cudaMallocManaged((void **) &F4_re, 4 * 4 * sizeof(half)));
checkCudaErrors(cudaMallocManaged((void **) &X_split, M * 4 * B * 4 * sizeof(half)));
checkCudaErrors(cudaMallocManaged((void **) &result1, M * 4 * B * 4 * sizeof(float)));
F4_re[0] = 1.0f;
F4_re[1] = 1.0f;
F4_re[2] = 1.0f;
F4_re[3] = 1.0f;
F4_re[4] = 1.0f;
F4_re[5] = 0.0f;
F4_re[6] =-1.0f;
F4_re[7] = 0.0f;
F4_re[8] = 1.0f;
F4_re[9] =-1.0f;
F4_re[10] = 1.0f;
F4_re[11] =-1.0f;
F4_re[12] = 1.0f;
F4_re[13] = 0.0f;
F4_re[14] =-1.0f;
F4_re[15] = 0.0f;
srand(time(NULL));
for (int i = 0; i < M * 4 * B * 4; i++) {
X_split[i] = (float)rand() / (float)(RAND_MAX) * 2 * UPPER_BOUND - UPPER_BOUND;
}
cublasStatus_t status;
cublasHandle_t handle;
float alpha = 1.0f, beta = 0.0f;
status = cublasCreate(&handle);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! CUBLAS initialization error\n");
exit(1);
}
status = cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); // allow Tensor Core
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! CUBLAS setting math mode error\n");
exit(1);
}
long long int stride = M * 4;
status = cublasGemmStridedBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, M, 4, 4, &alpha, X_split,
CUDA_R_16F, M, stride, F4_re, CUDA_R_16F, 4, 0, &beta, result1, CUDA_R_32F, M, stride, B * 4, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! CUBLAS kernel execution error: %d .\n", status);
exit(1);
}
status = cublasDestroy(handle);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! shutdown error (A)\n");
exit(1);
}
checkCudaErrors(cudaFree(F4_re));
checkCudaErrors(cudaFree(X_split));
checkCudaErrors(cudaFree(result1));
return 0;
}
解决方案
此问题应在刚刚发布的 CUBLAS 补丁版本中得到纠正,可在此处获得
寻找这个补丁:
补丁 1(2018 年 8 月 6 日发布)
这是一个必须在正确安装 CUDA 9.2.148 的基础上安装的补丁。
首先,您必须安装 CUDA 9.2.148。然后安装补丁。
推荐阅读
- javascript - 我可以在画布中完成css变换rotateY效果吗
- php - 使用以下代码获取数组的最小值
- c# - MassTrant - 消费者 DI 生命周期和注册选项
- python - 为什么在 elif 语句中出现语法错误?
- css - 需要 HTML 输入搞乱了我的输入
- javascript - 第 1 页,共 0 页,即使 jqgrid 包含数据
- jquery - 如何使用jQuery获取范围之间的值
- php - 在 PHP 代码上发送 HTML 电子邮件给我一个语法错误
- windows - 递归地将文件扩展名更改为小写
- swift - 如何编写一个 NSPredicate 来搜索一对多关系的日期属性?