首页 > 解决方案 > numba 的有趣行为 - 使用 argmax() 的 guvectorized 函数

问题描述

考虑以下脚本:

from numba import guvectorize, u1, i8
import numpy as np

@guvectorize([(u1[:],i8)], '(n)->()')
def f(x, res):
    res = x.argmax()

x = np.array([1,2,3],dtype=np.uint8)
print(f(x))
print(x.argmax())
print(f(x))

运行它时,我得到以下信息:

4382569440205035030
2
2

为什么会这样?有没有办法让它正确?

标签: vectorizationnumbaargmax

解决方案


Python 没有引用,因此res = ...实际上并没有分配给输出参数,而是重新绑定了 name res。我相信 res 指向的是未初始化的内存,这就是为什么你的第一次运行会给出一个看似随机的值。

Numba 使用会改变 res 的切片语法 ( [:]) 解决此问题 - 您还需要将类型声明为数组。一个工作函数是:

@guvectorize([(u1[:], i8[:])], '(n)->()')
def f(x, res):
    res[:] = x.argmax()

推荐阅读