python - 如果我使用 pycuda 将数组传递给 GPU 然后打印它,为什么它打印零?
问题描述
我正在尝试使用 pycuda 来加速我的神经网络(我知道 tensorflow 更容易用于 GPU 加速,我只是想先手动完成,因为我对神经网络比较陌生),但是每当我将数组传递给GPU 并让每个线程在 threadIdx 处打印出数组的值,即使我设置了数组值,它也会打印零。
我尝试使用一个非常简单的内核进行测试,它只打印一维数组的值,并且我尝试将数据类型更改为 float32。
我用于测试此问题的基本内核:
test_mod = SourceModule("""
__global__ void test(float *a)
{
printf("%d: %d\\n", threadIdx.x, a[threadIdx.x]);
}
""")
我用来创建数组和初始化内核的 python 代码:
a = np.asarray([4,2,1])
a = a.astype(np.float32)
test_module = test_mod.get_function("test")
test_module(cuda.In(a), block=(3, 1, 1))
我希望它打印一些 4、2 和 1 的顺序,但每个线程都打印一个 0。
解决方案
问题在于内核中的打印语句。格式说明%d
符用于整数。它不会正确格式化浮点值。要修复它,请像这样修改内核:
test_mod = SourceModule("""
__global__ void test(float *a)
{
printf("%d: %f\\n", threadIdx.x, a[threadIdx.x]);
}
""")
[从评论中收集答案并添加为社区 wiki 条目,以尝试将问题从 CUDA 标签的未回答队列中删除]。
推荐阅读
- html - 引导选项卡的多个数据目标不起作用
- typescript - Typescript Type 'XX' 不可分配给类型 import("/Volumes/D/test").XX'.ts(2322)
- java - 使用递归的回文链表
- flutter - 运行 Flutter Web 时的设备/浏览器详细信息
- java - 检查参数是否是对类的特定静态字段的引用
- mysql - 就 ERD 而言,这种关系是什么?
- jsp - Servlet 错误:ClassNotFoundException:HttpServletRequest
- tcsh - 创建一个包含 foreach 和 if 语句的别名
- latex - 乳胶列表 emph
- react-native - 我可以在 React Native 中为每种语言使用不同的字体吗?