Parallelizing with numba's prange while passing a tuple along

Problem description

I tried to parallelize the following, closely following the example in the documentation:

import numba
import numpy as np

@numba.jit(nopython=True)
def foo(uIdx, grids):
    return uIdx

@numba.jit(nopython=True, parallel=True)
def bar(grid, grids):

    LIdxGrid = np.zeros(len(grid))

    for uIdx in numba.prange(len(grid)):
        LIdxGrid[uIdx] = foo(uIdx, grids)
    return LIdxGrid



if __name__ == '__main__':
    grid = np.arange(12)
    grids = (grid, grid)
    bar(grid, grids)

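But it doesn't seem to work. The problem seems to come from passing grids (which is not even used in the final foo function). If I remove that reference in foo and bar, it works: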

bar(grid, 0)
Out[47]: array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])

How can I fix or work around this?

The full traceback is:

Traceback (most recent call last):
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/errors.py", line 491, in new_error_context
    yield
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/lowering.py", line 216, in lower_block
    self.lower_inst(inst)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/lowering.py", line 365, in lower_inst
    func(self, inst)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/npyufunc/parfor.py", line 119, in _lower_parfor_parallel
    index_var_typ)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/npyufunc/parfor.py", line 691, in call_parallel_gufunc
    sout, {})
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/npyufunc/parallel.py", line 251, in build_gufunc_wrapper
    cache=cache)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/npyufunc/wrappers.py", line 460, in build_gufunc_wrapper
    return wrapcls(py_func, cres, sin, sout, cache).build()
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/npyufunc/wrappers.py", line 411, in build
    self._build_wrapper(wrapperlib, wrapper_name)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/npyufunc/wrappers.py", line 372, in _build_wrapper
    arg_steps, i, step_offset, typ, sym, sym_dim)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/npyufunc/wrappers.py", line 614, in __init__
    "argument #{1}".format(typ, i + 1))
TypeError: scalar type tuple(array(int64, 1d, C) x 2) given for non scalar argument #2
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-46-b6a12a1ce616>", line 3, in <module>
    bar(grid, grids)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/dispatcher.py", line 360, in _compile_for_args
    raise e
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/dispatcher.py", line 311, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/dispatcher.py", line 618, in compile
    cres = self._compiler.compile(args, return_type)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/dispatcher.py", line 83, in compile
    pipeline_class=self.pipeline_class)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 871, in compile_extra
    return pipeline.compile_extra(func)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 365, in compile_extra
    return self._compile_bytecode()
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 802, in _compile_bytecode
    return self._compile_core()
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 789, in _compile_core
    res = pm.run(self.status)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 251, in run
    raise patched_exception
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 243, in run
    stage()
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 676, in stage_nopython_backend
    self._backend(lowerfn, objectmode=False)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 626, in _backend
    lowered = lowerfn()
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 613, in backend_nopython_mode
    self.flags)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/compiler.py", line 990, in native_lowering_stage
    lower.lower()
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/lowering.py", line 135, in lower
    self.lower_normal_function(self.fndesc)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/lowering.py", line 176, in lower_normal_function
    entry_block_tail = self.lower_function_body()
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/lowering.py", line 201, in lower_function_body
    self.lower_block(block)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/lowering.py", line 216, in lower_block
    self.lower_inst(inst)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/errors.py", line 499, in new_error_context
    six.reraise(type(newerr), newerr, tb)
  File "/home/foo/anaconda3/envs/myenv3/lib/python3.6/site-packages/numba/six.py", line 659, in reraise
    raise value
numba.errors.LoweringError: Failed at nopython (nopython mode backend)
scalar type tuple(array(int64, 1d, C) x 2) given for non scalar argument #2
File "<ipython-input-44-ec97cbf0b87b>", line 9:
def bar(grid, grids):
    <source elided>
    LIdxGrid = np.zeros(len(grid))
    ^
[1] During: lowering "id=7[LoopNest(index_variable = parfor_index.317, range = (0, grid_size0.315, 1))]{51: <ir.Block at <ipython-input-44-ec97cbf0b87b> (9)>}Var(parfor_index.317, <ipython-input-44-ec97cbf0b87b> (9))" at <ipython-input-44-ec97cbf0b87b> (9)
-------------------------------------------------------------------------------
This should not have happened, a problem has occurred in Numba's internals.

Tags: python, numpy, parallel-processing, numba

Solution


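Support for reference-counted items (such as np.ndarrays) is fairly new (since numba 0.39), and I'm not sure whether tuples of reference-counted items already work. AFAIK, tuples of reference-counted items are not supported yet. So, to make sure your code works, you have to replace the tuple with a list: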

if __name__ == '__main__':
    import numpy as np
    grid = np.arange(12)
    grids = [grid, grid]
    bar(grid, grids)

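And make sure you have numba version 0.39 installed! Otherwise this won't work.
Of course a list is not a tuple, so this is only a workaround. But there is no other way around the problem as long as tuples of reference-counted items are not fully supported.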

