首页 > 解决方案 > python使用ctypes将二维数组传递给c函数

问题描述

我想用 C 来处理一些计算。例如,我有一个添加两个矩阵的 C 函数:

// mat_add.c
#include <stdlib.h>

void matAdd(int ROW, int COL, int x[][COL], int y[][COL], int z[][COL]){
    int i, j;
    for (i = 0; i < ROW; i++){
        for (j = 0; j < COL; j++){
            z[i][j] = x[i][j] + y[j][j];
        }
    }
}

然后我将它编译成 .so 文件:
gcc -shared -fPIC mat_add.c -o mat_add.so

在python中:

# mat_add_test.py
import ctypes
import numpy as np

def cfunc(x, y):
    nrow, ncol = x.shape
    
    objdll = ctypes.CDLL('./mat_add.so')
    
    func = objdll.matAdd
    func.argtypes = [
        ctypes.c_int,
        ctypes.c_int,
        np.ctypeslib.ndpointer(dtype=np.int, ndim=2, shape=(nrow, ncol)),
        np.ctypeslib.ndpointer(dtype=np.int, ndim=2, shape=(nrow, ncol)),
        np.ctypeslib.ndpointer(dtype=np.int, ndim=2, shape=(nrow, ncol))
    ]
    func_restype = None
    
    z = np.empty_like(x)
    func(nrow, ncol, x, y, z)
    return z


if __name__ == '__main__':
    x = np.array([[1, 2], [3, 4]], dtype=np.int)
    y = np.array([[2, 2], [5, 6]], dtype=np.int)
    z = cfunc(x, y)
    print(z)
    print('end')

执行这个python文件,我得到:

$ python mat_add_test.py 
[[                  3                   4]
 [8386863780988286322 7813586346238636153]]
end

返回矩阵的第一行是正确的,但第二行是错误的。我想我没有成功更新 中的值z,但我不知道问题出在哪里。
任何人都可以帮忙吗?很感谢!

标签: pythoncctypes

解决方案


问题中二维数组的处理是正确的。唯一的问题(除了 C 代码如何索引y数组的拼写错误 -y[j][j]应该是y[i][j])是np.int所以np.int64这不对应于 C int

为了确保类型匹配,可以在两种语言中指定显式长度。

在 Python 中:使用np.int32ornp.int64显式(而不是np.int)。

在 C: 中#include <stdint.h>,然后相应地使用int32_tor int64_t(可能通过 a typedef),而不是int.

然后问题就消失了。

对于ROWand COL,这些是按值调用的,因此它不太重要(当然前提是值不会溢出)。

这里发生了什么

实际上,二维数组仍然只是内存中的一维值序列。2维只是索引它的一种方便方法。

因此,在 numpy 中,调用 C 之前的数组是(十六进制):

0000000000000001 0000000000000002 0000000000000003 0000000000000004  <== x
0000000000000002 0000000000000002 0000000000000005 0000000000000006  <== y
UUUUUUUUUUUUUUUU UUUUUUUUUUUUUUUU UUUUUUUUUUUUUUUU UUUUUUUUUUUUUUUU  <== z 

其中U表示未定义/未初始化的数据

但是在 C 代码中(假设是小端序),将数组视为 32 位,它会看到:

inputs
00000001 00000000 00000002 00000000 00000003 00000000 00000004 00000000  <== x
00000002 00000000 00000002 00000000 00000005 00000000 00000006 00000000  <== y
UUUUUUUU UUUUUUUU UUUUUUUU UUUUUUUU UUUUUUUU UUUUUUUU UUUUUUUU UUUUUUUU  <== z at start

然后 C 代码循环遍历每个元素的前 4 个元素,执行加法运算,因此产生:

00000003 00000000 00000004 00000000 UUUUUUUU UUUUUUUU UUUUUUUU UUUUUUUU  <== z at end

然后使用 64 位 int 类型返回 numpy,现在我们看到:

0000000000000003 0000000000000004 UUUUUUUUUUUUUUUU UUUUUUUUUUUUUUUU  <== output z

解释为二维数组,这是array([[3, 4], [whatever, whatever]])


推荐阅读