首页 > 解决方案 > 如何在 cython 的函数参数中键入函数

问题描述

我已经制作了一个在 python 中进行优化的函数(我们称之为optimizer)。它需要将函数优化(我们称之为objective)作为函数参数之一。objective是一个接受一维np.ndarray并返回数字的函数(这与C++float中的相同?)。double

我已经阅读了这篇文章,但我不确定它是否与我的问题以及我使用时的问题相同,但在编译过程ctypedef int (*f_type)(int, str)中出现错误。Cannot convert 'f_type' to Python object它仅适用于 C 函数吗?如何键入 python 函数?

编辑:我的代码如何:

cpdef optimizer(objective, int num_particle, int dim,
             np.ndarray[double, ndim=1] lower_bound,
             np.ndarray[double, ndim=1] upper_bound):

    cdef double min_value
    cdef np.ndarray[double, ndim=2] positions = np.empty((num_particle,dim), dtype=np.double)
    cdef np.ndarray[double, ndim=1] fitness = np.empty(num_particle, dtype=np.double)
    cdef int i, j

    # do lots of stuff not shown here
    # involve the following code:
    for i in range(num_particle):
        fitness[i] = objective(positions[i])

    return min_value

我想知道是否可以键入objective以使代码运行得更快。

标签: pythoncythonstatic-typing

解决方案


我收到错误消息

Cannot convert Python object argument to type 'f_type'

我认为这比你声称得到的更有意义 - 你试图将 Python 对象传递给函数。请确保您报告的错误消息是您的代码实际生成的错误消息。您对所采用类型的描述objective也与您显示的代码不匹配。


然而,一般来说:不,你不能给你的目标函数一个类型说明符来加速它。一个通用的 Python 可调用对象携带比 C 函数指针更多的信息(例如引用计数、任何闭包捕获变量的详细信息等)。

一种可能的替代方法是从cdef class具有适当cdef功能的 a 继承,因此您至少可以在特定情况下获得适当的性能:

# an abstract function pointer class
cdef class FPtr:
    cdef double function(self,double[:] x) except? 0.0:
        # I'm assuming you might want to pass exceptions back to Python - use 0.0 to indicate that there might have been an error
        raise NotImplementedError()

# an example class that inherits from the abstract pointer type    
cdef class SumSq(FPtr):
    cdef double function(self,double[:] x) except? 0.0:
        cdef double sum=0.0
        for i in range(x.shape[0]):
            sum += x[i]**2
        return sum

# an example class that just wraps a Python callable
# this will be no faster, but makes the code generically usable
cdef class PyFPtr(FPtr):
    cdef object f
    def __init__(self,f):
        self.f = f

    cdef double function(self,double[:] x) except? 0.0:
        return self.f(x) # will raise an exception if the types don't match

def example_function(FPtr my_callable):
    import numpy as np
    return my_callable.function(np.ones((10,)))

使用它example_function(SumSq())可以按预期工作(并且具有 Cython 速度);example_function(PyFPtr(lambda x: x[0]))按预期工作(可调用中没有 Cython 速度);example_function(PyFPtr(lambda x: "hello"))按预期给出类型错误。


推荐阅读