python - 从扩展函数返回后使用自定义 Python 扩展的堆损坏 (0xC0000374)
问题描述
我正在为我的 Python 脚本编写一个扩展模块,但无法弄清楚这个堆损坏的来源。扩展模块的使用应该是创建一个 Numpy 数组,用给定一些条件计算的一些值填充它,并将 Numpy 数组返回给 Python。有一个 Python 函数(在下面的示例中get_data()
),它包装了扩展函数并从中接收 Numpy 数组。当 Python 函数返回时,会发生错误。
当我使用"OOO"
而不是"NNN"
作为Py_BuildValue()
.
因此,我认为这是关于引用计数的问题。如果您查看下面的示例代码,是否有人熟悉这个问题?
开始使用 WinDbg,但完全是初学者。
注意:不幸的是不能共享原始代码,这当然使这有点困难。下面的示例代码对我没有任何错误,但是它使用了相同的概念,问题可能出在其中。
模块.cpp
#define PY_SSIZE_T_CLEAN
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#include <Python.h>
#include <numpy\arrayobject.h>
// Helper for making some exemplary changes on data array
void applySomeChangesOnData(double* data, double factor, size_t nRows, size_t nColumns) {
for (int m = 0; m < nRows; m++) {
for (int n = 0; n < nColumns; n++) {
data[m*nColumns+n] = factor * (m + n);
}
}
}
PyObject* createArrays(PyObject *self, PyObject *args) {
// Variables needed for parsing inputs
PyObject* tupleShape;
// Parse input
if (!PyArg_ParseTuple(args, "O!", &PyTuple_Type, &tupleShape)) {
// Error when parsing
PyErr_SetString(PyExc_TypeError, "Bad input type(s)");
Py_RETURN_NONE;
}
// Allow only 2D arrays for this example
if ((size_t)PyTuple_Size(tupleShape) != 2) {
PyErr_SetString(PyExc_ValueError, "Array must be 2D");
Py_RETURN_NONE;
}
// Allocate data array
size_t nRows = PyLong_AsLong(PyTuple_GetItem(tupleShape, 0));
size_t nColumns = PyLong_AsLong(PyTuple_GetItem(tupleShape, 1));
size_t nElements = nRows * nColumns;
double* data1 = PyMem_New(double, nElements);
double* data2 = PyMem_New(double, nElements);
double* data3 = PyMem_New(double, nElements);
// Modify data array
applySomeChangesOnData(data1, 1, nRows, nColumns);
applySomeChangesOnData(data2, 10, nRows, nColumns);
applySomeChangesOnData(data3, 100, nRows, nColumns);
// Prepare shape information
size_t nDims = 1;
npy_intp* shape = new npy_intp[nDims];
shape[0] = nElements;
// Create output array and set OWNDATA flag for proper deallocation
PyArrayObject* arr1 = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(nDims, shape, NPY_DOUBLE, data1));
PyArrayObject* arr2 = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(nDims, shape, NPY_DOUBLE, data2));
PyArrayObject* arr3 = reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(nDims, shape, NPY_DOUBLE, data3));
if (!arr1 || !arr2 || !arr3) {
PyErr_SetString(PyExc_RuntimeError, "Failed when creating output array");
Py_RETURN_NONE;
}
PyArray_ENABLEFLAGS(arr1, NPY_ARRAY_OWNDATA);
PyArray_ENABLEFLAGS(arr2, NPY_ARRAY_OWNDATA);
PyArray_ENABLEFLAGS(arr3, NPY_ARRAY_OWNDATA);
// Some clean-up
delete[] shape; shape = NULL;
// Return multiple outputs
PyObject* ret = Py_BuildValue("NNN", arr1, arr2, arr3);
return ret;
}
static PyMethodDef extension_methods[] = {
{ "create_arrays", (PyCFunction)createArrays, METH_VARARGS, nullptr },
{ nullptr, nullptr, 0, nullptr },
};
static PyModuleDef extension_module = { PyModuleDef_HEAD_INIT, "extension", "Some docs...", 0, extension_methods };
PyMODINIT_FUNC PyInit_extension() {
import_array();
return PyModule_Create(&extension_module);
}
安装程序.py
import os
import sys
from setuptools import setup, Extension, find_packages
# Get installation path of Python interpreter
(path_interpreter, _) = os.path.split(sys.executable)
ext_module = Extension('example_package.extension',
sources=['example_package/module.cpp',],
include_dirs=[os.path.join(path_interpreter, 'Lib/site-packages/numpy/core/include')],
extra_compile_args=['/Zi'],
extra_link_args=['/DEBUG'])
setup(
name='example_package',
version='0.1',
ext_modules=[ext_module],
)
debug_extension.py
import numpy as np
from example_package.extension import create_arrays
shape = (4000, 6000)
def get_data():
# Write 5 pairs of arrays into this list
list_pairs = []
for i in range(50):
# Get 3 arrays from extension
(arr1, arr2, arr3) = create_arrays(shape)
# Reshape, transpose and later combine them using np.stack()
arr1 = np.reshape(arr1, shape).transpose()
arr2 = np.reshape(arr2, shape).transpose()
arr3 = np.reshape(arr3, shape).transpose()
list_pairs.append([
np.stack([arr1, arr2], axis=0),
np.stack([arr1, arr3], axis=1),
])
return list_pairs
list_pairs = get_data()
解决方案
推荐阅读
- sql-server - 使用等效的 SET STATISTICS 监控 SSIS 数据流
- javascript - 只需使用 JS 在 DOM 中插入一些结构化的自定义 HTML
- amazon-web-services - 无法将 KMS 密钥应用到 AWS CloudWatch 日志组
- python - 具有相关名称的反向查找
- c# - 布尔值的 C# 条件不一样
- python - 在 Python 中撤销 celery 任务时如何退出 selenium webdriver
- freemarker - Freemaker 以下已评估为 null 或缺失
- amazon-web-services - [返回结果不正确]AWS Redshift(RedShift)中Join语句中Limit不正确
- ruby-on-rails - ActiveStorage:从 s3 打开已处理的变体
- python - Excel图表中的下标字母