首页 > 解决方案 > 如何将 2D numpy 数组转换为 Cython 中的指针数组?

问题描述

我想使用 Cython 包装具有以下签名的 C/C++ 函数。

// my_tool.h
float func(const float** points, int n, int m) {
    float result = 0.;
    // ...
    for (int i=0; i<n; ++i) {
        for (int j=0; j<m; ++j) {
            // ... points[i][j]
        }
    }
    return result
}

包装器将收到一个 2D-numpy 数组。数组不需要是连续的(例如,compute也可以采用数组切片:)arr[:, 2:-2]

# my_tool.pyx
import numpy as np
cimport cython

cdef extern from "my_tool.h":
    int func(const float** points, int n_points)

def compute(float[:,:] points):
    # Assure float-ndarray.
    points = np.asarray(points, dtype=float)
    # Create a memoryview.
    cdef float[:,:] points_view = points
    # >>> The following line will lead to a syntax error:
    # >>> "Expected an identifier or literal"
    cdef float*[:] points_ptr = [&points_view[i] for i in points.shape[0]]
    
    ret = func(&points_ptr[0], points.shape[0])
    return ret

问题:如何将二维数组的内存视图func传递给它的签名(C 风格的指针列表)匹配?

# This is how I want to use my wrapped tool in python.
import mytool
import numpy as np
points = np.random.rand(10,2)
ret = mytool.compute(points)

更新/解决方案:这篇文章回答了这个问题。我的解决方案看起来类似于:

from cpython.mem cimport PyMem_Malloc, PyMem_Free

def compute(float[:,:] points):
    # Assure float-ndarray and create a typed memoryview.
    if False:
        points = np.asarray(points, dtype=float)
        cdef float[:,:] points_view = points
    else:
        # If a contiguous array is advantageous.
        points = np.ascontiguousarray(points, dtype=float)
        cdef float[:,::1] points_view = points

    # Create a helper container with pointers of each point.
    cdef size_t n_bytes = points.shape[0] * sizeof(float*)
    cdef float_type** point_ptrs = <float_type **>PyMem_Malloc(n_bytes)
    if not point_ptrs:
        raise MemoryError
    try:
        for i in range(points.shape[0]):
            point_ptrs[i] = &points[i, 0]
        # Call the C function that expects a float_type**
        ret = func(point_ptrs, points.shape[0])
    finally:
        PyMem_Free(point_ptrs)

标签: pythonnumpycython

解决方案


推荐阅读