首页 > 解决方案 > python多处理:使用多处理时丢失了一些类属性

问题描述

我的问题的原始版本

我正在尝试使用scipy.optimize.brute.

如果给定 4 个参数,则可以评估成本函数,但是这 4 个参数必须遵循一些条件。

为了处理它和其他一些复杂的问题,我创建了我的 python 类,Parameter如下例所示,但是当我通过workers关键字使用多处理时,一些属性丢失了。

我的问题的简化版本

import numpy as np
from multiprocessing import Pool

class Parameter(np.ndarray):
    def __new__(cls, maximum):
        self = np.asarray([0., 0., 0., 0.], dtype=np.float64).view(cls)
        return self

    def __init__(self, maximum):
        self.maximum = maximum
        self.validity = True

    def isvalid(self):
        if self.sum() <= self.maximum:
            return True
        else:
            return False

    def set(self, arg):
        for i in range(4):
            self[i] = arg[i]
        self.validity = self.isvalid()

def cost(arg, para):
    para.set(arg)
    if para.validity:
        return para.sum()
    else:
        return para.maximum

class CostWrapper:
    def __init__(self, f, args):
        self.f = f
        self.args = [] if args is None else args

    def __call__(self, x):
        return self.f(np.asarray(x), *self.args)

if __name__ == '__main__':
    parameter = Parameter(100)
    wrapped_cost = CostWrapper(cost, (parameter,))
    parameters_to_be_evaluated = [np.random.rand(4) for _ in range(4)]
    with Pool(2) as p:
        res = p.map(wrapped_cost, parameters_to_be_evaluated)

,这提高了

  File "\_bug_attribute_lose.py", line 126, in isvalid
    if self.sum() <= self.maximum:
AttributeError: 'Parameter' object has no attribute 'maximum'

但是,如果我使用wrapped_costwithout p.map,如下所示不会引发错误。

wrapped_cost(np.random.rand(4))

我试过的

通过在我的代码周围放置一些打印消息,我发现__new____init__方法都只调用一次,所以我猜多处理库以某种方式复制了parameter.

另外,我发现复制的版本parameter仅包含 np.ndarray 具有的属性:

dir(para) = ['T', '__abs__', '__add__', '__and__', '__array__', '__array_finalize__', '__array_function__', '__array_interface__', '__array_prepare__', '__array_priority__', '__array_struct__', '__array_ufunc__', '__array_wrap__', '__bool__', '__class__', '__complex__', '__contains__', '__copy__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__divmod__', '__doc__', '__eq__', '__float__', '__floordiv__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__iand__', '__ifloordiv__', '__ilshift__', '__imatmul__', '__imod__', '__imul__', '__index__', '__init__', '__init_subclass__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__or__', '__pos__', '__pow__', '__radd__', '__rand__', '__rdivmod__', '__reduce__', '__reduce_ex__', '__repr__', '__rfloordiv__', '__rlshift__', '__rmatmul__', '__rmod__', '__rmul__', '__ror__', '__rpow__', '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__', '__rxor__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__sub__', '__subclasshook__', '__truediv__', '__xor__', 'all', 'any', 'argmax', 'argmin', 'argpartition', 'argsort', 'astype', 'base', 'byteswap', 'choose', 'clip', 'compress', 'conj', 'conjugate', 'copy', 'ctypes', 'cumprod', 'cumsum', 'data', 'diagonal', 'dot', 'dtype', 'dump', 'dumps', 'fill', 'flags', 'flat', 'flatten', 'getfield', 'imag', 'isvalid', 'item', 'itemset', 'itemsize', 'max', 'mean', 'min', 'nbytes', 'ndim', 'newbyteorder', 'nonzero', 'partition', 'prod', 'ptp', 'put', 'ravel', 'real', 'repeat', 'reshape', 'resize', 'round', 'searchsorted', 'set', 'setfield', 'setflags', 'shape', 'size', 'sort', 'squeeze', 'std', 'strides', 'sum', 'swapaxes', 'take', 'tobytes', 'tofile', 'tolist', 'tostring', 'trace', 'transpose', 'var', 'view']

(请参阅既不存在“最大值”也不存在“有效性”)

因此,我尝试在类中实现__copy__方法,例如Parameter

def __copy__(self):
    print('__copy__')
    new = Parameter(self.maximum)
    new.__dict__.update(self.__dict__)
    return new

,但失败了。

我的问题:

  1. Parameter该对象应该丢失的一些属性。我的猜测是因为multiprocessing库以某种方式复制了变量parameter,但我没有正确实现复制方法。我对吗?

  2. 如果是这样,我该怎么做?如果没有,请告诉我是哪个导致错误。

标签: pythonmultiprocessing

解决方案


这有点棘手,但这是可能的。

首先,当从您继承时,np.ndarray应该定义__array_finalize____new__. 请注意,__array_finalize__由于某种原因被多次调用,因此您必须引入一个空保护。更多关于这个的文档

def __array_finalize__(self, obj):
    if obj is None: return
    self.maximum = getattr(obj, 'maximum', None)
    self.validity = getattr(obj, 'validity', None)

其次,在使用picklemultiprocessing.Pool将数据发送给工作人员之前对数据进行序列化。在此过程中,您的额外属性将丢失。所以我们必须在继续之前将它们添加回来。

覆盖__reduce__方法:

def __reduce__(self):
    pickled_state = super().__reduce__()
    new_state = pickled_state[2] + (self.__dict__, )
    return (*pickled_state[0:2], new_state)

并覆盖__setstate__方法:

def __setstate__(self, state):
        self.__dict__.update(state[-1])
        super().__setstate__(state[0:-1])

该实现是从这个答案中借来的。

好的,现在让我们将它组合成一个可运行的代码:

import numpy as np
from multiprocessing import Pool

class Parameter(np.ndarray):
    def __new__(cls, maximum):
        obj = np.asarray([0, 0, 0, 0], dtype=np.float64).view(cls)
        obj.maximum = maximum
        obj.validity = True
        return obj
    
    def __array_finalize__(self, obj):
        if obj is None: return
        self.maximum = getattr(obj, 'maximum', None)
        self.validity = getattr(obj, 'validity', None)

    def __reduce__(self):
        pickled_state = super().__reduce__()
        new_state = pickled_state[2] + (self.__dict__, )
        return (*pickled_state[0:2], new_state)
    
    def __setstate__(self, state):
        self.__dict__.update(state[-1])
        super().__setstate__(state[0:-1])

    def isvalid(self):
        return self.sum() <= self.maximum

    def set(self, arg):
        for i in range(4):
            self[i] = arg[i]
        self.validity = self.isvalid()

def cost(arg, para):
    para.set(arg)
    return para.sum() if para.validity else para.maximum

class CostWrapper:
    def __init__(self, f, args):
        self.f = f
        self.args = () if args is None else args

    def __call__(self, x):
        return self.f(np.asarray(x), *self.args)

if __name__ == '__main__':
    parameter = Parameter(100)
    wrapped_cost = CostWrapper(cost, (parameter,))
    parameters_to_be_evaluated = [np.random.rand(4) for _ in range(4)]
    with Pool(2) as p:
        res = p.map(wrapped_cost, parameters_to_be_evaluated)

顺便问一下,你知道这个问题已经存在了吗?在这里。但它不会与多个属性共享您的问题(这是一个简单的解决方案),所以这次我会让您放松一些。


推荐阅读