首页 > 解决方案 > 为什么 numba 中的 guvectorize 返回奇怪的输出?

问题描述

例如,我无法理解以下函数的输出:

@nb.guvectorize(["void(float64, float64, float64, float64)"], "(),()->(),()")
def add_subtract(a, b,  res1, res2):
    res1 = a + b
    res2 = a - b

当我调用它时,

add_subtract(np.array([3.0,5.0]), np.array([6.0,7.0]))

奇怪的是,它会显示:

(array([6.95307160e-310, 1.23643049e-311]), array([0., 0.]))

看来这个函数试图返回 0。为什么这个函数没有返回 [[9,-3],[12,-2]]?

标签: pythonnumpynumba

解决方案


输出数组未初始化地传递给函数。如果为res1res2locals 分配新值,则将数组替换为新数组。这不会更改传递给函数的数组,这些数组保持未初始化。

相反,您需要替换它们的内容。以下示例有效:

@nb.guvectorize("void(float64[:], float64[:], float64[:], float64[:])",
                "(),()->(),()")
def add_subtract(a, b, res1, res2):
    res1[:] = a + b
    res2[:] = a - b

推荐阅读