首页 > 解决方案 > 如何在 Numba Vectorize 签名中指定元组?

问题描述

我正在定义一个函数,并希望使用 Numba Vectorize 通过 cuda 来加速它。我在使用函数签名时遇到问题。该函数将返回一个 float64 值。我想传递两个将被矢量化的 float64 值,以及一个 9 元组的 float64 值,它们将是标量。

这是我的函数头:

from numba import vectorize

@vectorize(['float64(float64, float64, UniTuple(float64, 9))'], target='cuda')
def fn_vec(E, L, fparams):
    # calculations... 
    return result

但这给出了一个错误:

TypeError: data type "(float64 x 9)" not understood

我尝试了许多变体,包括 (float64, ..., float64) 代替 UniTuple(),但无法正常工作。我该怎么做呢?

标签: pythontypesvectorizationnumba

解决方案


如何在 Numba Vectorize 签名中指定元组?

numba.vectorize函数中不能使用元组。那是因为vectorize向量化了这些类型数组的代码。

因此,使用float, float, tuple签名会创建一个函数,该函数需要两个包含浮点数的数组和一个包含元组的数组。问题是包含元组的数组没有 dtype - 如果您使用结构化数组而不是包含元组的数组,它可能会起作用,但我没有尝试过。

如何在 Numbajit签名中指定元组?

UniTuple在 numba 签名中指定 a 的正确方法是使用numba.types.containers.UniTuple. 在你的情况下:

nb.types.containers.UniTuple(nb.types.float64, 9)

所以正确的签名应该是这样的:

import numba as nb

@nb.njit(
    nb.types.float64(
        nb.types.float64, 
        nb.types.float64, 
        nb.types.containers.UniTuple(nb.types.float64, 9)))
def func(f1, f2, ftuple):
    # ...
    return f1

我经常避免显式输入我的 numba 函数 - 但是当我这样做时,我发现它非常有用numba.typeof,例如:

>>> nb.typeof((1.0, ) * 9)
tuple(float64 x 9)

>>> type(nb.typeof((1.0, ) * 9))
numba.types.containers.UniTuple

>>> help(type(nb.typeof((1.0, ) * 9)))  # I shortened the result:
Help on class UniTuple in module numba.types.containers:

class UniTuple(BaseAnonymousTuple, _HomogeneousTuple, numba.types.abstract.Sequence)
 |  UniTuple(*args, **kwargs)
 |  
 |  Type class for homogeneous tuples.
 |  
 |  Methods defined here:
 |  
 |  __init__(self, dtype, count)
 |      Initialize self.  See help(type(self)) for accurate signature.

所以信息就在那里:它是numba.types.containes.UniTuple,你用两个参数实例化它,dtype(这里float64)和数字(在这种情况下9)。

如果您只想对浮点数组进行矢量化

如果你不想为元组参数向量化函数,你可以简单地在另一个函数中创建向量化函数并在那里调用它:

import numba as nb
import numpy as np

def func(E, L, fparams):
    @nb.vectorize(['float64(float64, float64)'])
    def fn_vec(e, l):
        return e + l + fparams[1]  # just to illustrate that the tuple is available
    return fn_vec(E, L)

这使得元组在vectorized 函数中可用。但是,它必须创建内部函数并在每次调用外部函数时对其进行编译,因此这实际上可能会更慢。我也不确定这是否适用于target="cuda",您可能需要自己测试。


推荐阅读