首页 > 解决方案 > 为什么在 cython 函数中请求 numpy 数组的形状时,我得到 8 个维度?

问题描述

我有以下功能,

%%cython  
cdef cytest(double[:,:] arr):
  return print(arr.shape)      
def pytest(arr):
  return cytest(arr)  

我使用以下 numpy 数组运行 pytest,

dummy = np.ones((2,2))  
pytest(dummy)  

我得到以下结果,

[2, 2, 0, 0, 0, 0, 0, 0]

标签: pythoncython

解决方案


这是因为在 中C,数组的形状是固定的。数组的最大维数cython是 8。Cython 将数组的所有维数存储在这个固定长度的数组中。

这可以通过这样做来验证:

%%cython  
cdef cytest(double[:,:,:,:,:,:,:,:,:] arr): # works up to 8 ':'
    return arr.shape  
def pytest(arr):
    return cytest(arr)

编译它时,它会抛出这个错误:

Error compiling Cython file:
------------------------------------------------------------
...
cdef cytest(double[:,:,:,:,:,:,:,:,:] arr):
                  ^
------------------------------------------------------------

/path/to/_cython_magic_9a9aea2a10d5eb901ad6987411e371dd.pyx:1:19: More dimensions than the maximum number of buffer dimensions were used.

这实质上意味着预设的最大维数为 8,我假设您可以通过更改 cython_magic 源文件来更改它。


推荐阅读