python - 如何编译具有可变输入类型的 numba jit'ed 函数?
问题描述
假设我有一个函数可以同时接受一个int
或一个None
类型作为输入参数
import numba as nb
import numpy as np
jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}
@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
np.random.seed(None)
out = np.random.normal()
return out
我希望函数简单地返回一个正态分布的随机数。如果我想要可重现的结果,种子应该是int
.
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
如果我想要随机数,seed
应保留为None
. 但是,如果我不传递参数(因此种子默认为None
)或显式传递seed=None
,那么 numba 会引发TypeError
get_random()
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
get_random(None)
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
在这种情况下,我该如何编写函数,仍然声明签名和使用nopython
模式?
我的 numba 版本是 0.43.1
解决方案
第一个问题是 nopython 模式下的 numba 仅接受(从版本 0.43.1 开始)np.random.seed
:仅使用整数参数。
因此,很遗憾,您无法通过None
.
第二个问题是(据我所知)没有告诉 numba 如何处理缺失值的“单一”签名,但是您可以使用两个签名(是的,它非常冗长):
import numba as nb
import numpy as np
jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}
@nb.jit(
[nb.types.float64(nb.types.misc.Omitted(None)),
nb.types.float64(nb.types.int64)],
**jitkw)
def get_random(seed=None):
return np.random.normal()
只是关于签名的两个部分的简短说明:
- 如果省略参数,则告诉 numba 用作默认
nb.types.float64(nb.types.misc.Omitted(None))
类型None
- 是
nb.types.float64(nb.types.int64)
需要整数的签名。
就我个人而言,我不会指定签名,只是让 numba 弄清楚。显式签名在 numba 中很少值得,而且更常见的是,它们会导致代码变慢且不灵活。
推荐阅读
- drupal-8 - 如何在drupal8中实现邮政编码和城市的依赖下拉菜单?
- java - 存在 JNI API 调用时的 JUnit 测试覆盖率
- python - 在 Django 中保存多对多字段的实例
- python - 将 DataFrame.from_records 添加到现有的 df
- while-loop - 如何通过在函数中使用 while 循环返回包含异常的数组的总和?
- python - 我无法显示标题、公司、位置、日期。我试过 end='',但它不起作用
- java - Fat Jar 导出:找不到“logging/bin/default”的类路径条目
- javascript - 如何从轮播中的图像以表单形式提交信息
- reactjs - 如何在reactjs的输入字段内添加测量单位
- javascript - 动态参数的 URL 正则表达式匹配