首页 > 解决方案 > Python C-API:PyDict_GetItem 上的分段错误,可能的参考问题?

问题描述

我在 C 中引用了一个 Python 字典列表。我正在编写一个函数来计算列表的两个成员之间的点积:

PyObject *handle; // reference to a list of dictionaries
virtual float dot_product () (unsigned i, unsigned j) const {
    // dot product of handle[i] and handle[j]
    PyObject *a = (PyObject*)PyList_GetItem(handle, (Py_ssize_t)i);
    PyObject *b = (PyObject*)PyList_GetItem(handle, (Py_ssize_t)j);
    PyObject *key, *a_value;
    Py_ssize_t pos = 0;
    double dot_product = 0;
    while (PyDict_Next(a, &pos, &key, &a_value)) {
        PyObject* b_value = PyDict_GetItem(b, key);
        if (b_value != NULL){
            dot_product += PyFloat_AsDouble(a_value) * PyFloat_AsDouble(b_value);
        }
    }
    return dot_product;
}

这会导致分段错误。使用 gdb 进行调试,看来分段错误是由 PyDict_GetItem(b, key) 引起的。这让我怀疑我的引用计数有问题。

阅读有关Reference Counts的文档后,似乎上述代码中的所有引用都是借用的,所以我认为没有必要使用 Py_INCREF 或 Py_DECREF ......但我很容易出错。上面的代码中是否有我需要使用 Py_INCREF 或 Py_DECREF 的地方?

编辑:我应该注意我已经完成了检查以确保 a 和 b 不为空,并且还检查以确保 i 和 j 不超过列表的大小。我从问题中的代码中删除了这些检查以使其更简单。——</p>

标签: pythonc++python-c-api

解决方案


检查您的返回值。两者ab都可以是NULLifi和分别超过 所引用j的长度。假设传递的不是指针,在不确认该假设的情况下取消引用它,这将导致立即的段错误。listhandlePyDict_GetItem dictNULL

您的主要问题是确定如何报告错误。C++ 异常适用于 C++,但 Python 无法理解它,除非您捕获它并将其转换为 Python 级别的异常。无论如何,在你弄清楚之前,返回NaN表示失败:

#include <cmath>

PyObject *handle; // reference to a list of dictionaries
virtual float dot_product () (unsigned i, unsigned j) const {
    // dot product of handle[i] and handle[j]
    PyObject *key, *a_value;
    Py_ssize_t pos = 0;
    double dot_product = 0;

    // Check both indices are valid
    PyObject *a = (PyObject*)PyList_GetItem(handle, (Py_ssize_t)i);
    if (!a) return NAN;

    PyObject *b = (PyObject*)PyList_GetItem(handle, (Py_ssize_t)j);
    if (!b) return NAN;

    // Test if you actually got dicts
    if (!PyDict_Check(a) || !PyDict_Check(b)) return NAN;

    while (PyDict_Next(a, &pos, &key, &a_value)) {
        PyObject* b_value = PyDict_GetItem(b, key);
        if (b_value != NULL){
            // Check that both values are really Python floats and extract C double
            double a_val = PyFloat_AsDouble(a_value);
            if (a_val == -1.0 && PyErr_Occurred()) return NAN;

            double b_val = PyFloat_AsDouble(b_value);
            if (b_val == -1.0 && PyErr_Occurred()) return NAN;

            dot_product += a_val * b_val;
        }
    }
    return dot_product;
}

推荐阅读