首页 > 解决方案 > 编程嵌套 numba.cuda 函数调用

问题描述

Numba 和 CUDA 菜鸟在这里。我希望能够让一个numba.cuda函数以编程方式从设备调用另一个函数,而无需将任何数据传回主机。例如,给定设置

from numba import cuda

@cuda.jit('int32(int32)', device=True)
def a(x):
    return x+1

@cuda.jit('int32(int32)', device=True)
def b(x):
    return 2*x

我希望能够定义一个组合内核函数,例如

@cuda.jit('void(int32, __device__, int32)')
def b_comp(x, inner, result):
    y = inner(x)
    result = b(y)

并成功获得

b_comp(1, a, result)
assert result == 4

理想情况下,我希望b_comp在编译后接受不同的函数参数[例如,在上述调用之后,仍然接受b_comp(1, b, result)]——但是函数参数在编译时固定的解决方案仍然对我有用。

根据我的阅读,CUDA 似乎支持传递函数指针。 这篇帖子表明numba.cuda没有这样的支持,但帖子没有说服力,而且也是一岁。numba.cuda 中支持的 Python 页面没有提到函数指针支持。但它链接到numba页面中支持的 Python,这清楚地表明numba.jit() 确实支持函数作为参数,尽管它们在编译时得到修复。如果numba.cuda.jit()像我上面所说的那样做同样的事情,那就行了。在那种情况下,在为 指定签名时comp,我应该如何声明变量类型?或者我可以使用numba.cuda.autojit()吗?

如果numba不支持任何这样的直接方法,元编程是一个合理的选择吗?例如,一旦我知道该inner函数,我的脚本就可以创建一个包含组成这些特定函数的 python 函数的新脚本,然后 apply numba.cuda.jit(),然后导入结果。看起来很复杂,但这是numba我能想到的唯一基于其他的选项。

如果numba根本无法做到这一点,或者至少没有严重的麻烦,我会很高兴得到一个提供一些细节的答案,以及像“切换到 PyCuda”这样的建议。

标签: pythoncudanumba

解决方案


这对我有用:

  1. 最初没有装饰我的功能cuda.jit,因此它们仍然具有该__name__属性
  2. 获取__name__属性
  3. 现在cuda.jit通过直接调用装饰器来应用我的函数
  4. 为字符串中的组合函数创建python,并将其传递给exec

确切的代码:

from numba import cuda
import numpy as np


def a(x):
    return x+1

def b(x):
    return 2*x


# Here, pretend we've been passed the inner function and the outer function as arguments
inner_fun = a
outer_fun = b

# And pretend we have noooooo idea what functions these guys actually point to
inner_name = inner_fun.__name__
outer_name = outer_fun.__name__

# Now manually apply the decorator
a = cuda.jit('int32(int32)', device=True)(a)
b = cuda.jit('int32(int32)', device=True)(b)

# Now construct the definition string for the composition function, and exec it.
exec_string = '@cuda.jit(\'void(int32, int32[:])\')\n' \
              'def custom_comp(x, out_array):\n' \
              '    out_array[0]=' + outer_name + '(' + inner_name + '(x))\n'

exec(exec_string)

out_array = np.array([-1])
custom_comp(1, out_array)
print(out_array)

正如预期的那样,输出是

[4]

推荐阅读