cuda - CUDA直方图问题
问题描述
我对生成直方图的简单 CUDA 代码有疑问:
#include <math.h>
#include <numeric>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 256
// Counts occurrences of the values in dev_values_arr: each block accumulates
// into the shared array temp, then adds its per-bin counts into dev_histogram.
//   dev_histogram  - output bins; this kernel touches index threadIdx.x + 1
//   dev_values_arr - input values; each value is used directly as an index into temp
//   size           - number of elements read from dev_values_arr
__global__ void kernel_histogram(int* dev_histogram, int* dev_values_arr, unsigned int size) {
// Per-block counters; slot 0 is never zeroed or flushed by this kernel.
__shared__ int temp[BLOCK_SIZE + 1];
int thread_id, thread_value;
thread_id = threadIdx.x + blockIdx.x * blockDim.x;
// NOTE(review): threads past the end of the input return here, i.e. before
// both __syncthreads() calls below and before the final flush. Bin
// threadIdx.x + 1 of such a thread is therefore neither zeroed nor added to
// dev_histogram, and the barriers are not reached by every thread of the block.
if (thread_id >= size) {
return;
}
// Each surviving thread clears the single bin it is responsible for.
temp[threadIdx.x + 1] = 0;
__syncthreads();
thread_value = dev_values_arr[thread_id];
// NOTE(review): thread_value is not range-checked before indexing temp;
// presumably inputs lie in 1..BLOCK_SIZE - confirm against the caller.
atomicAdd(&temp[thread_value], 1);
__syncthreads();
// Each surviving thread adds its own bin's block-local count to the global bin.
atomicAdd(&(dev_histogram[threadIdx.x + 1]), temp[threadIdx.x + 1]);
}
// Prints the error line for a failed CUDA runtime call and terminates the
// process; does nothing when status is cudaSuccess.
static void check_cuda(cudaError_t status, const char* operation) {
    if (status != cudaSuccess) {
        printf("ERROR: CUDA %s operation failed - %s\n", operation,
               cudaGetErrorString(status));
        exit(-1);
    }
}

// Computes a histogram of values_arr on the GPU.
//
// Parameters:
//   values_arr - host array of `size` ints to be counted.
//   size       - number of elements in values_arr.
// Returns:
//   A heap-allocated array of BLOCK_SIZE + 1 ints copied back from the device
//   histogram; the caller owns it and must release it with free().
//   When values_arr is NULL or size <= 0 an all-zero histogram is returned
//   without touching the GPU.
// Errors:
//   Any allocation or CUDA failure prints a message and calls exit(-1).
int* histogram_cuda(int* values_arr, int size) {
    const size_t histogram_bytes = (BLOCK_SIZE + 1) * sizeof(int);

    // calloc so the early-return path below hands back well-defined zeros.
    int* histogram = (int*)calloc(BLOCK_SIZE + 1, sizeof(int));
    if (histogram == NULL) {
        printf("ERROR: host allocation of the histogram failed\n");
        exit(-1);
    }

    // Nothing to count; a launch with zero blocks would be an invalid configuration.
    if (values_arr == NULL || size <= 0) {
        return histogram;
    }

    const size_t values_bytes = (size_t)size * sizeof(int);
    // Ceiling division written without (size + BLOCK_SIZE - 1), which could
    // overflow int for very large sizes.
    const int num_blocks = size / BLOCK_SIZE + (size % BLOCK_SIZE != 0 ? 1 : 0);

    int* dev_histogram = NULL;
    int* dev_values_arr = NULL;

    // allocate histogram and values_arr device memories
    check_cuda(cudaMalloc((void**)&dev_histogram, histogram_bytes), "cudaMalloc()");
    check_cuda(cudaMemset(dev_histogram, 0, histogram_bytes), "cudaMemset()");
    check_cuda(cudaMalloc((void**)&dev_values_arr, values_bytes), "cudaMalloc()");

    // copy values_arr memory in host to device
    check_cuda(cudaMemcpy(dev_values_arr, values_arr, values_bytes,
                          cudaMemcpyHostToDevice),
               "cudaMemcpy()");

    printf("the number of blocks is %d\n\n", num_blocks);

    // calculate histogram on the gpu
    kernel_histogram<<<num_blocks, BLOCK_SIZE>>>(dev_histogram, dev_values_arr, size);
    // A launch does not return a status itself; configuration errors are
    // reported through cudaGetLastError().
    check_cuda(cudaGetLastError(), "kernel launch");

    // copy histogram memory in device to host (blocking copy, so it also
    // waits for the kernel to finish)
    check_cuda(cudaMemcpy(histogram, dev_histogram, histogram_bytes,
                          cudaMemcpyDeviceToHost),
               "cudaMemcpy()");

    // free device memory
    check_cuda(cudaFree(dev_histogram), "cudaFree()");
    check_cuda(cudaFree(dev_values_arr), "cudaFree()");

    return histogram;
}
// Demo driver: builds the histogram of a fixed sample array on the GPU and
// prints every non-empty bin as "<value> : <count>".
int main(int argc, char* argv[]) {
    (void)argc;
    (void)argv;

    int values_arr[] = { 2, 2, 2, 2, 2, 2, 2, 4, 5, 5, 5, 5, 7, 7, 7, 7, 19, 20, 21, 100, 256 };
    // Derive the element count from the initializer so the two cannot drift apart.
    const unsigned int size = sizeof(values_arr) / sizeof(values_arr[0]);

    int* histogram = histogram_cuda(values_arr, (int)size);

    // Bin 0 is skipped: printing starts at value 1.
    for (int i = 1; i < BLOCK_SIZE + 1; i++) {
        if (histogram[i] > 0) {
            printf("%d : %d\n", i, histogram[i]);
        }
    }

    // histogram_cuda returns a heap buffer that the caller must release.
    free(histogram);
    return 0;
}
直方图用于记录输入中存在的值的数量,允许的值为 1 到 256。每个块最多有 256 个线程。我试图限制跨块的总线程数,以便每个线程记录直方图中一个值的出现。
如果我使用“values_arr = { 2, 2, 2, 2, 2, 2, 2, 4, 5, 5, 5, 5, 7, 7, 7, 7, 19, 20, 21, 100, 256 }”这意味着大小是 21,我得到:
2:7 4:1 5:4 7:4 19:1 20:1 21:1
我希望每个输入值都由一个线程负责记录,并让所有多余的线程直接退出。此外,如果您发现了其他问题,或者对如何以最佳方式解决这个问题有任何建议,我将不胜感激。谢谢!
解决方案
在您问题中新版本的代码里,有两处 __syncthreads() 调用是有条件执行的(部分线程在到达屏障之前就已经 return)。这在 CUDA 中是非法的,可能导致死锁或产生未定义的行为,具体取决于您的硬件和使用场景。
如果我这样修改内核:
// Block-local histogram kernel for values in the range 1..BLOCK_SIZE.
//
// Each block counts its slice of the input into the shared array temp and
// then adds the non-zero per-bin counts into dev_histogram with atomics.
//
// Parameters:
//   dev_histogram  - output bins; indices 1..BLOCK_SIZE are updated, so it
//                    must hold at least BLOCK_SIZE + 1 ints and be zeroed
//                    by the caller before the launch.
//   dev_values_arr - input values, one element per thread.
//   size           - number of valid elements in dev_values_arr.
// Launch: 1D grid with at least `size` threads in total; any blockDim.x works
// because bins are zeroed and flushed with block-stride loops.
// Shared memory: (BLOCK_SIZE + 1) * sizeof(int) static.
// Values outside 1..BLOCK_SIZE are ignored rather than indexing temp out of bounds.
__global__ void kernel_histogram(int* dev_histogram, int* dev_values_arr, unsigned int size) {
    __shared__ int temp[BLOCK_SIZE + 1];

    // Clear every bin, even when blockDim.x differs from BLOCK_SIZE.
    for (unsigned int bin = threadIdx.x + 1; bin <= BLOCK_SIZE; bin += blockDim.x) {
        temp[bin] = 0;
    }
    // Synchronization is unconditional: every thread of the block reaches it.
    __syncthreads();

    // Load is performed conditionally; threads past the end just skip it.
    const unsigned int thread_id = threadIdx.x + blockIdx.x * blockDim.x;
    if (thread_id < size) {
        const int thread_value = dev_values_arr[thread_id];
        // Guard the shared-memory index against out-of-range input.
        if (thread_value >= 1 && thread_value <= BLOCK_SIZE) {
            atomicAdd(&temp[thread_value], 1);
        }
    }
    // Synchronization is unconditional: all counts are complete past this point.
    __syncthreads();

    // Flush the block's counts; empty bins are skipped to avoid needless
    // global atomics.
    for (unsigned int bin = threadIdx.x + 1; bin <= BLOCK_SIZE; bin += blockDim.x) {
        const int count = temp[bin];
        if (count != 0) {
            atomicAdd(&dev_histogram[bin], count);
        }
    }
}
我得到这个输出:
the number of blocks is 1
2 : 7
4 : 1
5 : 4
7 : 4
19 : 1
20 : 1
21 : 1
100 : 1
256 : 1
在我看来,这个结果更符合预期。
推荐阅读
- javascript - 尝试使用类组件中的方法更新状态中的道具
- selenium-webdriver - 在 Azure DevOps 管道上运行的 SpecFlow BDD UI 测试
- spring-boot - 如果在 spring boot hibernate 中违反 saveAll()
- python - asyncio 不正确地警告流对象被垃圾收集;显式调用“stream.close()”
- rust - 有没有一种简单的方法来找出一个 Vector 是否在 Rust 中被 None 填充?
- python-3.x - 尝试使用 python3 从 .pdf 文件中提取地理坐标
- c - 验证整数值;c程序
- python - 无法在python中比较两个对象的数组
- react-native - 使用本机反应中的数据导航到另一个页面
- html - CSS多级下拉菜单不起作用