首页 > 解决方案 > 设置数组元素时的奇怪结果(C++/pybind11)

问题描述

我正在尝试使用 pybind11 编写 C++ 扩展。该函数最终将几个 numpy-arrays 作为输入并返回几个 numpy-arrays 作为输出。我一直在尝试直接传递目标 numpy-arrays 并在 C++ 函数中就地更改值。

我正在使用以下代码进行测试:

C++

#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"

namespace py = pybind11;

int nparraytest(py::array_t<float> A, py::array_t<float> B, py::array_t<float> C, py::array_t<float> D) {
    py::buffer_info Abuf = A.request();
    py::buffer_info Bbuf = B.request();
    py::buffer_info Cbuf = C.request();
    py::buffer_info Dbuf = D.request();

    int lxA = Abuf.shape[0];

    double* Aptr = (double*)Abuf.ptr;
    double* Bptr = (double*)Bbuf.ptr;
    double* Cptr = (double*)Cbuf.ptr;
    double* Dptr = (double*)Dbuf.ptr;

    for (int i = 0; i < lxA; i++) {
        int i30 = 3 * i;
        int i31 = i30 + 1;
        int i32 = i30 + 2;

        //optionA
        //Cptr[i30] = Aptr[i30];
        //Cptr[i31] = Aptr[i31];
        //Cptr[i32] = Aptr[i32];

        //Option B
        Cptr[i30] = Aptr[i30];
        Cptr[i31] = Aptr[i31];
        Cptr[i32] = Bptr[2];

        Dptr[i30] = Aptr[i30];
        Dptr[i31] = Aptr[i31];
        Dptr[i32] = Aptr[i32];
    }

    return lxA;
}

PYBIND11_MODULE(SeqRT, m) {
    m.def("nparraytest", &nparraytest, "Test");
}

Python

import numpy as np
import SeqRT

if __name__ == '__main__':
    #initialize arrays
    A = np.arange(15, dtype = np.float32).reshape((5,3))
    B = np.arange(15, dtype = np.float32).reshape((5,3))
    C = np.zeros((5,3), dtype = np.float32)
    D = np.zeros((5,3), dtype = np.float32)
    
    lxA = SeqRT.nparraytest(A, B, C, D)
    
    print(lxA)
    print(A)
    print(B)
    print(C)
    print(D)

现在,无论我使用选项 A 还是选项 B 中的代码,数组 A、B 和 D 总是按预期结束,以及数组 C,即

[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]
 [ 9. 10. 11.]
 [12. 13. 14.]]

但是,使用选项 BI 获得数组 C 的此结果

[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]
 [ 9.  4.  5.]
 [12. 13. 14.]]

如您所见,值 10. 和 11. 不同。事实上,将其他输入作为数组 B 时,值 4. 和 5. 也可能非常随机。相反,我希望这样:

[[ 0.  1.  2.]
 [ 3.  4.  2.]
 [ 6.  7.  2.]
 [ 9. 10.  2.]
 [12. 13.  2.]]

我不知道我的错误是什么:

标签: c++pybind11

解决方案


推荐阅读