首页 > 解决方案 > Numba:选择数组的 N 维子集

问题描述

让我们有一些data选择index

import numpy as np
shape = 50, 50, 50, 50
data = np.linspace(0.0, 10.0 - 1.0e-3, int(np.prod(shape))).reshape(shape)
index = (data > data.mean()).astype(bool)

现在,在常规中numpy,我可以简单地

In: data[index]
Out: array([4.9995008, 4.9995024, 4.999504 , ..., 9.9989968, 9.9989984,
   9.999    ])

numba如果我不事先展平阵列,显然不能这样做,这有点贵:

@jit(nopython=True, parallel=False)
def select_4d(data, index):
    return data[index]
select_4d(data, index)

导致错误(下)。有没有一种廉价的方法来解决这个问题而不使阵列变平?

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(float64, 4d, C), array(bool, 4d, C))
 
There are 22 candidate implementations:
  - Of which 20 did not match due to:
  Overload of function 'getitem': File: <numerous>: Line N/A.
    With argument(s): '(array(float64, 4d, C), array(bool, 4d, C))':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 162.
    With argument(s): '(array(float64, 4d, C), array(bool, 4d, C))':
   Rejected as the implementation raised a specific error:
     TypeError: unsupported array index type array(bool, 4d, C) in [array(bool, 4d, C)]
  raised from /home/saman/anaconda3/envs/myenv3/lib/python3.9/site-packages/numba/core/typing/arraydecl.py:68

标签: pythonnumpynumba

解决方案


推荐阅读