python - 如何在 Numba Vectorize 签名中指定元组?
问题描述
我正在定义一个函数,并希望使用 Numba Vectorize 通过 cuda 来加速它。我在使用函数签名时遇到问题。该函数将返回一个 float64 值。我想传递两个将被矢量化的 float64 值,以及一个 9 元组的 float64 值,它们将是标量。
这是我的函数头:
from numba import vectorize
@vectorize(['float64(float64, float64, UniTuple(float64, 9))'], target='cuda')
def fn_vec(E, L, fparams):
# calculations...
return result
但这给出了一个错误:
TypeError: data type "(float64 x 9)" not understood
我尝试了许多变体,包括 (float64, ..., float64) 代替 UniTuple(),但无法正常工作。我该怎么做呢?
解决方案
如何在 Numba Vectorize 签名中指定元组?
在numba.vectorize
函数中不能使用元组。那是因为vectorize
向量化了这些类型数组的代码。
因此,使用float, float, tuple
签名会创建一个函数,该函数需要两个包含浮点数的数组和一个包含元组的数组。问题是包含元组的数组没有 dtype - 如果您使用结构化数组而不是包含元组的数组,它可能会起作用,但我没有尝试过。
如何在 Numba
jit
签名中指定元组?
UniTuple
在 numba 签名中指定 a 的正确方法是使用numba.types.containers.UniTuple
. 在你的情况下:
nb.types.containers.UniTuple(nb.types.float64, 9)
所以正确的签名应该是这样的:
import numba as nb
@nb.njit(
nb.types.float64(
nb.types.float64,
nb.types.float64,
nb.types.containers.UniTuple(nb.types.float64, 9)))
def func(f1, f2, ftuple):
# ...
return f1
我经常避免显式输入我的 numba 函数 - 但是当我这样做时,我发现它非常有用numba.typeof
,例如:
>>> nb.typeof((1.0, ) * 9)
tuple(float64 x 9)
>>> type(nb.typeof((1.0, ) * 9))
numba.types.containers.UniTuple
>>> help(type(nb.typeof((1.0, ) * 9))) # I shortened the result:
Help on class UniTuple in module numba.types.containers:
class UniTuple(BaseAnonymousTuple, _HomogeneousTuple, numba.types.abstract.Sequence)
| UniTuple(*args, **kwargs)
|
| Type class for homogeneous tuples.
|
| Methods defined here:
|
| __init__(self, dtype, count)
| Initialize self. See help(type(self)) for accurate signature.
所以信息就在那里:它是numba.types.containes.UniTuple
,你用两个参数实例化它,dtype
(这里float64
)和数字(在这种情况下9
)。
如果您只想对浮点数组进行矢量化
如果你不想为元组参数向量化函数,你可以简单地在另一个函数中创建向量化函数并在那里调用它:
import numba as nb
import numpy as np
def func(E, L, fparams):
@nb.vectorize(['float64(float64, float64)'])
def fn_vec(e, l):
return e + l + fparams[1] # just to illustrate that the tuple is available
return fn_vec(E, L)
这使得元组在vectorize
d 函数中可用。但是,它必须创建内部函数并在每次调用外部函数时对其进行编译,因此这实际上可能会更慢。我也不确定这是否适用于target="cuda"
,您可能需要自己测试。
推荐阅读
- python - 无法将 JSON 加载到 MongoDB,因为 MongoDB 特定符号
- java - Java 数据库 - 绝对 URL
- java - 如何修复 Jhipster aws 子生成器因“无法解析命令行选项:无法识别的选项:-ntp”错误而失败
- php - 子文件夹的 PHP 命名空间问题
- html - 如何为文本静音类向右浮动/向右拉动
- javascript - sSearch:过滤记录不从首字母开始过滤,而是包含
- lua - 如何在 Lua 中替换部分字符串?
- python - 将浮点值分配给数组时出现以下错误:(“'numpy.float64' 对象不支持项分配”,'发生在索引 4')
- junit - 如何使用 JUnit 测试运行时异常场景
- vba - 网页浏览器过期 VB