python - numpy数组的共享字典?
问题描述
我想把许多 numpy 数组存储在一个 dict 中,并在多个进程之间共享这个 dict。
import ctypes
import multiprocessing
from typing import Dict, Any
import numpy as np
# Manager-backed dict intended to be shared between processes.
# Values assigned into it are pickled and sent to the manager's server process,
# so each lookup hands back a fresh copy. This is why a write to an array that
# was fetched from this dict is not visible on the next lookup (see the demo below).
# NOTE(review): the Manager is started at import time, outside the
# ``__main__`` guard. Under the "spawn" start method, a child that re-imports
# this module would try to start another Manager. Consider creating it in main.
dict_of_np: Dict[Any, np.ndarray] = multiprocessing.Manager().dict()
def get_numpy(key):
    """Return the numpy array stored in ``dict_of_np`` under *key*.

    On first use of *key*, build a 5-element int32 array on top of a
    ``multiprocessing.Array`` buffer and store it in the dict. The value is
    always read back from the dict. Because the dict is a Manager proxy, what
    comes back is a copy, not the buffer-backed array itself.
    """
    if key in dict_of_np:
        return dict_of_np[key]
    backing = multiprocessing.Array(ctypes.c_int32, 5)
    dict_of_np[key] = np.frombuffer(backing.get_obj(), dtype=np.int32)
    return dict_of_np[key]
if __name__ == "__main__":
    # Modify the array handed back for key "5"...
    first = get_numpy("5")
    first[1] = 5
    print(first)  # prints [0 5 0 0 0]
    # ...then look the same key up again: the modification is gone.
    second = get_numpy("5")
    print(second)  # prints [0 0 0 0 0]
我按照“在多进程中使用共享内存中的 numpy 数组”一文中的说明,基于缓冲区创建了 numpy 数组;但是当我把生成的 numpy 数组保存到 dict 中时,它就不起作用了。如上所示,再次用同一个键访问 dict 时,之前对该 numpy 数组所做的修改没有被保留。
如何共享一个由 numpy 数组组成的字典?我需要字典和其中的数组都能被共享,并且各进程使用的是同一块内存。
解决方案
根据我们对这个问题的讨论,我想出了一个可能的解决方案:在主进程中用一个线程来负责实例化 multiprocessing.shared_memory.SharedMemory 对象。这样可以确保对共享内存对象的引用一直存在,底层内存不会被过早删除。不过,这只解决了“当不再有任何引用指向该文件时,文件会被删除”这一时间窗口的问题;它并不能解决另一个要求:只要还需要底层内存的视图,就必须保留每一个已打开的实例。
此管理器线程在输入消息队列(multiprocessing.Queue)上“监听”,负责创建共享内存对象并返回有关它们的数据。锁用于确保每个响应都由发出请求的那个进程读取(否则各进程的响应可能会混淆)。
所有共享内存对象都先由主进程创建,并一直保留到被显式删除为止,以便其他进程可以访问它们。
例子:
import multiprocessing
from multiprocessing import shared_memory, Queue, Process, Lock
from threading import Thread
import numpy as np
class Exit_Flag:
    """Sentinel placed on the controller's message queue to stop its thread."""
class SHMController:
    """Own SharedMemory segments and answer requests about them from a thread.

    Clients put a ``"method:arg"`` string on ``mq`` and read the reply from
    ``rq``. A client is expected to hold ``lock`` for the whole put/get round
    trip, so that a reply is read by the process that asked for it.

    Requests (replies start with ``ok:`` on success and ``ko:`` on failure):

    * ``exists:<name>``  -> ``ok:true`` / ``ok:false``
    * ``size:<name>``    -> ``ok:<len of the segment buffer>`` / ``ko:-1``
    * ``create:<size>`` or ``create:<name>,<size>`` -> ``ok:<segment name>``
    * ``destroy:<name>`` -> close and unlink the segment

    Segments stay referenced in ``_shm_objects`` until a ``destroy`` request
    removes them. Every request gets exactly one reply, even a malformed
    one, because a client blocked in ``rq.get()`` would otherwise wait forever.
    """

    def __init__(self):
        self._shm_objects = {}  # segment name -> SharedMemory instance
        self.mq = Queue()  # message input queue
        self.rq = Queue()  # response output queue
        self.lock = Lock()  # only let one child talk to the controller at a time
        self._processing_thread = Thread(target=self.process_messages)

    def start(self):
        """Start the processing thread (originally noted: call after child processes are started)."""
        self._processing_thread.start()

    def stop(self):
        """Ask the processing thread to exit once it reaches this request."""
        self.mq.put(Exit_Flag())

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        # Wait for the thread to drain the queue and close both queues.
        self._processing_thread.join()

    def process_messages(self):
        """Serve requests from ``mq`` until an ``Exit_Flag`` arrives, then close the queues."""
        while True:
            message_obj = self.mq.get()
            if isinstance(message_obj, Exit_Flag):
                break
            if isinstance(message_obj, str):
                try:
                    response = self.handle_message(message_obj)
                except Exception as exc:
                    # Keep serving other clients; report the failure to this one.
                    response = f"ko:{type(exc).__name__}: {exc}"
            else:
                response = f"ko:unsupported message type {type(message_obj).__name__}"
            self.rq.put(response)
        self.mq.close()
        self.rq.close()

    def handle_message(self, message):
        """Execute one ``"method:arg"`` request and return the reply string."""
        method, sep, arg = message.partition(":")
        if not sep:
            return f"ko:malformed message '{message}'"
        if method == "exists":
            return "ok:true" if arg in self._shm_objects else "ok:false"
        if method == "size":
            if arg in self._shm_objects:
                # Per the SharedMemory docs, this may be larger than the
                # requested size on platforms that round up to whole pages.
                return f"ok:{len(self._shm_objects[arg].buf)}"
            return "ko:-1"
        if method == "create":
            args = arg.split(",")  # "<size>" or "<name>,<size>"
            if len(args) == 1:
                name, size_text = None, args[0]
            elif len(args) == 2:
                name, size_text = args
            else:
                return f"ko:bad create arguments '{arg}'"
            try:
                size = int(size_text)
            except ValueError:
                return f"ko:bad size '{size_text}'"
            if name in self._shm_objects:
                return f"ko:'{name}' already created"
            try:
                shm = shared_memory.SharedMemory(name=name, create=True, size=size)
            except FileExistsError:
                return f"ko:'{name}' already exists"
            except ValueError as exc:  # e.g. a size of zero or less
                return f"ko:{exc}"
            self._shm_objects[shm.name] = shm
            return f"ok:{shm.name}"
        if method == "destroy":
            if arg not in self._shm_objects:
                return f"ko:'{arg}' does not exist"
            shm = self._shm_objects[arg]
            shm.close()
            shm.unlink()
            del self._shm_objects[arg]
            return f"ok:'{arg}' destroyed"
        return f"ko:unknown method '{method}'"
def create(mq, rq, lock):
    """Client demo: have the controller create segment "key123", then write (1, 2) into it.

    Sends ``create:key123,8`` and then ``size:<name>`` to the controller while
    holding *lock* for each round trip. It then attaches to the segment, fills
    the first two int32 values, and prints the replies and the array. Prints
    "Uh oh...." or "Oh no...." and returns early if a request fails.
    """
    # helper functions here could make access less verbose
    with lock:
        mq.put("create:key123,8")
        response = rq.get()
    print(response)
    if not response.startswith("ok"):
        print("Uh oh....")
        return
    name = response.split(':')[1]
    with lock:
        mq.put(f"size:{name}")
        response = rq.get()
    print(response)
    if not response.startswith("ok"):
        print("Oh no....")
        return
    size = int(response.split(":")[1])
    shm = shared_memory.SharedMemory(name=name, create=False, size=size)
    try:
        arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
        arr[:] = (1, 2)
        print(arr)
        del arr  # the array views the mapping; drop it before the mapping is closed
    finally:
        # Close this process's mapping even if the writes above fail.
        shm.close()
def modify(mq, rq, lock):
    """Client demo: wait for segment "key123" to exist, then add 5 to its first int32.

    Polls the controller with ``exists:key123`` until the reply is
    ``ok:true``, then asks for the segment size, attaches to the segment,
    and updates element 0. It prints the replies and the resulting array,
    or prints "Oh no...." and returns if the size request fails.
    """
    import time  # used only for the polling back-off below

    while True:  # until the shm exists
        with lock:
            mq.put("exists:key123")
            response = rq.get()
        if response == "ok:true":
            print("key:exists")
            break
        # Back off briefly instead of spinning on the lock and queues.
        time.sleep(0.01)
    with lock:
        mq.put("size:key123")
        response = rq.get()
    print(response)
    if not response.startswith("ok"):
        print("Oh no....")
        return
    size = int(response.split(":")[1])
    shm = shared_memory.SharedMemory(name="key123", create=False, size=size)
    try:
        arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
        arr[0] += 5
        print(arr)
        del arr  # the array views the mapping; drop it before the mapping is closed
    finally:
        # Close this process's mapping even if the update above fails.
        shm.close()
def delete(mq, rq, lock):
    """Placeholder client for the controller's "destroy" request; does nothing yet.

    TODO: make a test for this?
    """
if __name__ == "__main__":
    # Threads and processes are mixed here, so children are started with "spawn".
    multiprocessing.set_start_method("spawn")
    with SHMController() as controller:
        channel = (controller.mq, controller.rq, controller.lock)
        # Run the two clients strictly one after the other: create, then modify.
        for client in (create, modify):
            worker = Process(target=client, args=channel)
            worker.start()
            worker.join()
    print("finished")
为了解决“每个 shm 实例都必须与使用它的数组存活同样长的时间”这一问题,你必须保留对该特定 shm 对象的引用。将该引用作为属性附加到一个自定义的数组子类上(代码改编自 numpy 子类化指南),就能很方便地让引用与数组一起保留:
class SHMArray(np.ndarray):
    """ndarray subclass that carries a reference to the SharedMemory it views.

    Adapted from the numpy subclassing guide ("slightly more realistic
    example: attribute added to existing array"):
    https://numpy.org/doc/stable/user/basics.subclassing.html
    Holding ``shm`` on the array (and on views derived from it) keeps the
    SharedMemory instance referenced for as long as those arrays exist.
    """

    def __new__(cls, input_array, shm=None):
        instance = np.asarray(input_array).view(cls)
        instance.shm = shm
        return instance

    def __array_finalize__(self, obj):
        # Runs for views and slices too; copy the reference from the source array.
        if obj is not None:
            self.shm = getattr(obj, 'shm', None)
# example
# NOTE: `name` (the name of an existing segment) and `shape` must be defined
# by the surrounding code; this snippet does not define them.
shm = shared_memory.SharedMemory(name=name)
# Wrap the buffer-backed array so the SharedMemory instance travels with it.
np_array = SHMArray(np.ndarray(shape, buffer=shm.buf, dtype=np.int32), shm)
推荐阅读
- c++ - 字符串比较中的意外输出
- javascript - 自定义密码要求的正则表达式
- listview - 如何在flutter的页面中添加listview和tabbarview
- angular - 我无法在角材料垫选择中填充保存的百分比值
- mysql - 对话属于多个用户,但用户 A 删除而用户 B 不删除。我们如何防止它被退回?
- java - SparkSql Aerospike Java 连接器
- java - 如何在 Spring Boot 中从资源服务器中的令牌中提取声明
- html - 使用关键帧动画从左到右动画背景填充
- python - 仅在括号之间不使用逗号分割字符串
- cron - Cronjob 检查共享主机上修改的文件