首页 > 解决方案 > numpy数组的共享字典?

问题描述

我想用许多 numpy 数组存储一个 dict 并在进程之间共享它。

import ctypes
import multiprocessing
from typing import Dict, Any

import numpy as np

# Dict proxy obtained from a multiprocessing Manager; maps a key to a numpy array.
# NOTE(review): values put into a manager dict are presumably pickled into the
# manager process, so reading a key back yields a copy rather than the original
# buffer — confirm against the multiprocessing.managers documentation.
dict_of_np: Dict[Any, np.ndarray] = multiprocessing.Manager().dict()


def get_numpy(key):
    """Return the array stored under ``key`` in ``dict_of_np``.

    On first use of a key, a 5-element int32 ``multiprocessing.Array`` is
    allocated, wrapped with ``np.frombuffer`` and stored in the dict. The
    returned value is always whatever the dict hands back for ``key``, not the
    freshly wrapped array itself.
    """
    if key in dict_of_np:
        return dict_of_np[key]

    backing = multiprocessing.Array(ctypes.c_int32, 5)
    dict_of_np[key] = np.frombuffer(backing.get_obj(), dtype=np.int32)
    # Read the value back through the dict, as the lookup path above does.
    return dict_of_np[key]


if __name__ == "__main__":
    # Fetch the same key twice: a write made through the first result is not
    # visible through the second one.
    first = get_numpy("5")
    first[1] = 5
    print(first)  # prints [0 5 0 0 0]
    second = get_numpy("5")
    print(second)  # prints [0 0 0 0 0]

我按照《在共享内存中使用 numpy 数组进行多处理》一文中的说明,用缓冲区创建了 numpy 数组,但是当我把生成的 numpy 数组保存到 dict 中时,它并没有按预期工作。正如上面所示,再次用同一个键访问 dict 时,之前对 numpy 数组所做的修改并没有被保存。

如何共享 numpy 数组的字典?我需要共享字典和数组并使用相同的内存。

标签: python, numpy, python-multiprocessing

解决方案


根据我们对这个问题的讨论,我可能想出了一个解决方案:在主进程中使用一个线程来处理 multiprocessing.shared_memory.SharedMemory 对象的实例化,这样可以确保对共享内存对象的引用一直存在,底层内存不会被过早删除。这只解决了 Windows 上特有的问题,即一旦不再存在对文件的引用,文件就会被删除。它并不能解决另一个要求:只要还需要底层内存的视图,就必须一直保留每个已打开的实例。

此管理器线程在一个用于输入消息的 multiprocessing.Queue 上“侦听”,并创建共享内存对象或返回与其相关的数据。锁用于确保响应被正确的进程读取(否则响应可能会混淆)。

所有共享内存对象首先由主进程创建,并保留到显式删除,以便其他进程可以访问它们。

例子:

import multiprocessing
from multiprocessing import shared_memory, Queue, Process, Lock
from threading import Thread
import numpy as np

class Exit_Flag:
    """Marker type with no state; instances carry meaning only through their class."""
 
class SHMController:
    """Create shared-memory segments and keep them open on behalf of other processes.

    A background thread reads request strings of the form ``"<method>:<arg>"``
    from ``mq`` and puts exactly one reply on ``rq`` for every request. Replies
    start with ``"ok:"`` on success and ``"ko:"`` on failure. Clients should
    hold ``lock`` around each put/get pair so that a reply is read by the
    process that sent the request.

    Supported requests:
        ``exists:<name>``           -> ``ok:true`` or ``ok:false``
        ``size:<name>``             -> ``ok:<bytes>`` or ``ko:-1``
        ``create:<size>``           -> ``ok:<generated name>`` or ``ko:<reason>``
        ``create:<name>,<size>``    -> ``ok:<name>`` or ``ko:<reason>``
        ``destroy:<name>``          -> ``ok:'<name>' destroyed`` or ``ko:<reason>``

    Segments that were never destroyed are left untouched by ``stop()``.
    """

    def __init__(self):
        self._shm_objects = {}  # segment name -> SharedMemory object held open here
        self.mq = Queue()  # message input queue
        self.rq = Queue()  # response output queue
        self.lock = Lock()  # only let one child talk to you at a time
        self._processing_thread = Thread(target=self.process_messages)

    def start(self):  # to be called after all child processes are started
        """Start the thread that serves requests."""
        self._processing_thread.start()

    def stop(self):
        """Ask the processing thread to exit and wait until it has done so."""
        self.mq.put(Exit_Flag())
        # Joining makes shutdown deterministic. Skip it when the thread is not
        # running: join() raises on a thread that was never started.
        if self._processing_thread.is_alive():
            self._processing_thread.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def process_messages(self):
        """Serve requests from ``mq`` until an ``Exit_Flag`` arrives, then close both queues."""
        while True:
            message_obj = self.mq.get()
            if isinstance(message_obj, Exit_Flag):
                break
            if isinstance(message_obj, str):
                try:
                    response = self.handle_message(message_obj)
                except Exception as exc:
                    # A client is blocked on rq.get() while holding the lock;
                    # dying here would hang it forever, so report the failure.
                    response = f"ko:{type(exc).__name__}: {exc}"
            else:
                # Same reasoning: every request must receive a reply.
                response = f"ko:unsupported message type '{type(message_obj).__name__}'"
            self.rq.put(response)
        self.mq.close()
        self.rq.close()

    def handle_message(self, message):
        """Execute one ``"<method>:<arg>"`` request and return its reply string.

        Malformed or unknown requests produce a ``"ko:"`` reply instead of
        raising or returning ``None``.
        """
        method, separator, arg = message.partition(':')
        if not separator:
            return f"ko:malformed message '{message}'"
        if method == "exists":
            # Reports only segments created through this controller.
            return "ok:true" if arg in self._shm_objects else "ok:false"
        if method == "size":
            if arg in self._shm_objects:
                return f"ok:{len(self._shm_objects[arg].buf)}"
            return "ko:-1"
        if method == "create":
            return self._create(arg)
        if method == "destroy":
            if arg in self._shm_objects:
                self._shm_objects[arg].close()
                self._shm_objects[arg].unlink()
                del self._shm_objects[arg]
                return f"ok:'{arg}' destroyed"
            return f"ko:'{arg}' does not exist"
        return f"ko:unknown method '{method}'"

    def _create(self, arg):
        """Create a segment from ``"<size>"`` or ``"<name>,<size>"`` and return the reply."""
        args = arg.split(",")  # name, size or just size
        if len(args) == 1:
            name, size_text = None, args[0]  # None lets SharedMemory pick a name
        elif len(args) == 2:
            name, size_text = args
        else:
            return f"ko:expected 'size' or 'name,size', got '{arg}'"
        try:
            size = int(size_text)
        except ValueError:
            return f"ko:invalid size '{size_text}'"
        if name in self._shm_objects:
            return f"ko:'{name}' already created"
        try:
            shm = shared_memory.SharedMemory(name=name, create=True, size=size)
        except FileExistsError:
            return f"ko:'{name}' already exists"
        except (ValueError, OSError) as exc:
            # e.g. a zero/negative size or a name the OS rejects
            return f"ko:could not create shared memory: {exc}"
        self._shm_objects[shm.name] = shm
        return f"ok:{shm.name}"
    
def create(mq, rq, lock):
    """Request segment ``key123`` (8 bytes), attach to it and write ``(1, 2)`` into it.

    Each request is put on ``mq`` and its reply read from ``rq`` while holding
    ``lock``. Every reply is printed. If a reply does not start with ``"ok"``,
    a short message is printed and the function returns early.
    """
    #helper functions here could make access less verbose
    def ask(request):
        # One locked round trip: send the request, wait for its reply.
        with lock:
            mq.put(request)
            return rq.get()

    created = ask("create:key123,8")
    print(created)
    if created[:2] != "ok":
        print("Uh oh....")
        return

    name = created.split(':')[1]
    sized = ask(f"size:{name}")
    print(sized)
    if sized[:2] != "ok":
        print("Oh no....")
        return

    size = int(sized.split(":")[1])
    shm = shared_memory.SharedMemory(name=name, create=False, size=size)
    # View the segment as two int32 values and fill them in.
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[:] = (1, 2)
    print(arr)
    shm.close()
    
def modify(mq, rq, lock):
    """Wait for segment ``key123`` to exist, attach to it and add 5 to its first int32.

    Each request is put on ``mq`` and its reply read from ``rq`` while holding
    ``lock``. If the size request is not answered with an ``"ok"`` reply, a
    short message is printed and the function returns early.
    """
    def ask(request):
        # One locked round trip: send the request, wait for its reply.
        with lock:
            mq.put(request)
            return rq.get()

    # Keep asking until the shm exists.
    while ask("exists:key123") != "ok:true":
        pass
    print("key:exists")

    sized = ask("size:key123")
    print(sized)
    if sized[:2] != "ok":
        print("Oh no....")
        return

    size = int(sized.split(":")[1])
    shm = shared_memory.SharedMemory(name="key123", create=False, size=size)
    # View the segment as two int32 values and bump the first one.
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[0] += 5
    print(arr)
    shm.close()
    
def delete(mq, rq, lock):
    """Placeholder worker; it accepts the same arguments as the others and does nothing."""
    #TODO make a test for this?

 
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn") #because I'm mixing threads and processes
    with SHMController() as controller:
        worker_args = (controller.mq, controller.rq, controller.lock)
        # Run the workers strictly one after another: create first, then modify.
        for worker in (create, modify):
            task = Process(target=worker, args=worker_args)
            task.start()
            task.join()
    print("finished")

为了解决每个 shm 对象必须与使用它的数组存活同样长时间的问题,您必须保留对该 shm 对象的引用。把这个引用作为属性附加到自定义的数组子类上(子类代码复制自 numpy 的子类化指南),就可以很简单地让引用与数组保存在一起:

class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
    """ndarray view that carries an extra ``shm`` attribute.

    ``SHMArray(input_array, shm)`` views ``input_array`` as this class and
    stores ``shm`` on the result. Arrays derived from an instance copy the
    attribute from it; ``shm`` is ``None`` when the source has no such
    attribute.
    """

    def __new__(cls, input_array, shm=None):
        # Re-type the input as this class, then attach the reference.
        view = np.asarray(input_array).view(cls)
        view.shm = shm
        return view

    def __array_finalize__(self, obj):
        # obj is None only for explicit construction; nothing to inherit then.
        if obj is not None:
            self.shm = getattr(obj, 'shm', None)
#example
# `name` and `shape` are placeholders in this snippet: the name of the segment
# to attach to and the desired array shape must be supplied by the caller.
shm = shared_memory.SharedMemory(name=name)
# Passing shm as the second argument stores it on the array as np_array.shm.
np_array = SHMArray(np.ndarray(shape, buffer=shm.buf, dtype=np.int32), shm)

推荐阅读