Shared dict of numpy arrays?

Problem description

I want to store a dict with many numpy arrays and share it across processes.

import ctypes
import multiprocessing
from typing import Dict, Any

import numpy as np

dict_of_np: Dict[Any, np.ndarray] = multiprocessing.Manager().dict()


def get_numpy(key):
    if key not in dict_of_np:
        shared_array = multiprocessing.Array(ctypes.c_int32, 5)
        shared_np = np.frombuffer(shared_array.get_obj(), dtype=np.int32)
        dict_of_np[key] = shared_np
    return dict_of_np[key]


if __name__ == "__main__":
    a = get_numpy("5")
    a[1] = 5
    print(a)  # prints [0 5 0 0 0]
    b = get_numpy("5")
    print(b)  # prints [0 0 0 0 0]

I followed the instructions in "Use numpy array in shared memory for multiprocessing" to create the numpy arrays from a shared buffer, but when I try to store the resulting arrays in a dict, it doesn't work. As you can see above, changes to an array are lost when the dict is accessed again with the same key.
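Why the write is lost: a `multiprocessing.Manager().dict()` lives in a separate server process, so `dict_of_np[key] = shared_np` pickles the array and sends it there, and every `dict_of_np[key]` lookup sends back a freshly unpickled copy. Pickling an ndarray serializes its contents, not the shared buffer it points at, so `a` and `b` in the question are independent copies. The sketch below reproduces that round trip with `pickle` directly instead of starting a Manager (an assumption made so it runs standalone; the Manager's transport is pickle-based).

```python
import ctypes
import multiprocessing
import pickle

import numpy as np

# Build a numpy view onto a shared ctypes buffer, as in the question.
shared_array = multiprocessing.Array(ctypes.c_int32, 5)
shared_np = np.frombuffer(shared_array.get_obj(), dtype=np.int32)

# Simulate storing the array in a Manager dict and reading it back:
# the array's *contents* are serialized, the shared buffer is not.
roundtripped = pickle.loads(pickle.dumps(shared_np))

roundtripped[1] = 5
print(shared_np)                                  # [0 0 0 0 0]
print(np.shares_memory(roundtripped, shared_np))  # False
```

So a dict of arrays can only share memory if what travels between processes is a *name* for the memory (for example a `SharedMemory` name), not the array object itself.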

How can I share a dict of numpy arrays? I need both the dict and the arrays to be shared and use the same memory.

Recommended answer

Based on our discussion in this question, I may have come up with a solution: by using a thread in the main process to handle the instantiation of multiprocessing.shared_memory.SharedMemory objects, you can ensure that a reference to each shared memory object sticks around, so the underlying memory isn't deleted too early. This only solves the problem specific to Windows, where the backing file is deleted once no more references to it exist. It does not solve the problem that each open instance has to be held onto for as long as its underlying memoryview is needed.
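The lifetime rules come down to how `SharedMemory` handles relate to each other. A minimal sketch using only the standard `multiprocessing.shared_memory` API, in a single process for brevity (in the answer, the attaching handle would live in a child process): the creator keeps its handle open, a second handle attaches by name, and numpy views on both buffers see the same bytes. Only the creator calls `unlink()`, and numpy views must be dropped before `close()`.

```python
import numpy as np
from multiprocessing import shared_memory

# The creator (the main process in the answer) owns the segment and keeps
# this handle alive until nobody needs the memory any more.
owner = shared_memory.SharedMemory(create=True, size=2 * np.dtype(np.int32).itemsize)

# Any other handle attaches by name; in the answer this happens in a child.
other = shared_memory.SharedMemory(name=owner.name)

a = np.ndarray((2,), dtype=np.int32, buffer=owner.buf)
b = np.ndarray((2,), dtype=np.int32, buffer=other.buf)
a[:] = (1, 2)
b[0] += 5            # written through the second handle
result = a.tolist()  # read through the first handle
print(result)        # [6, 2]

# numpy views hold exports on shm.buf; close() raises BufferError while they exist.
del a, b
other.close()
owner.close()
owner.unlink()  # only the owner unlinks, and only once
```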

This manager thread "listens" for messages on an input multiprocessing.Queue, creates shared memory objects, and returns data about them on an output queue. A lock ensures that each response is read by the process that sent the request (otherwise responses could get mixed up between processes).

All shared memory objects are first created by the main process, and held onto until explicitly deleted so that other processes may access them.

Example:

import multiprocessing
from multiprocessing import shared_memory, Queue, Process, Lock
from threading import Thread
import numpy as np

class Exit_Flag: pass
 
class SHMController:
    def __init__(self):
        self._shm_objects = {}
        self.mq = Queue() #message input queue
        self.rq = Queue() #response output queue
        self.lock = Lock() #only let one child talk to you at a time
        self._processing_thread = Thread(target=self.process_messages)
    
    def start(self): #to be called after all child processes are started
        self._processing_thread.start()
        
    def stop(self):
        self.mq.put(Exit_Flag())
        self._processing_thread.join() #wait until the controller thread has shut down
        
    def __enter__(self):
        self.start()
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
    
    def process_messages(self):
        while True:
            message_obj = self.mq.get()
            if isinstance(message_obj, Exit_Flag):
                break
            elif isinstance(message_obj, str):
                message = message_obj
                response = self.handle_message(message)
                self.rq.put(response)
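        for shm in self._shm_objects.values(): #release segments that were never explicitly destroyed
            shm.close()
            shm.unlink()
        self._shm_objects.clear()
        self.mq.close()
        self.rq.close()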
    
    def handle_message(self, message):
        method, arg = message.split(':', 1)
        if method == "exists":
            if arg in self._shm_objects: #if shm.name exists or not
                return "ok:true"
            else:
                return "ok:false"
        if method == "size":
            if arg in self._shm_objects:
                return f"ok:{len(self._shm_objects[arg].buf)}"
            else:
                return "ko:-1"
        if method == "create":
            args = arg.split(",") #name, size or just size
            if len(args) == 1:
                name = None
                size = int(args[0])
            elif len(args) == 2:
                name = args[0]
                size = int(args[1])
            if name in self._shm_objects:
                return f"ko:'{name}' already created"
            else:
                try:
                    shm = shared_memory.SharedMemory(name=name, create=True, size=size)
                except FileExistsError:
                    return f"ko:'{name}' already exists"
                self._shm_objects[shm.name] = shm
                return f"ok:{shm.name}"
        if method == "destroy":
            if arg in self._shm_objects:
                self._shm_objects[arg].close()
                self._shm_objects[arg].unlink()
                del self._shm_objects[arg]
                return f"ok:'{arg}' destroyed"
            else:
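                return f"ko:'{arg}' does not exist"
        return f"ko:unknown method '{method}'" #always reply, or the requesting process blocks forever on rq.get()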
    
def create(mq, rq, lock):
    #helper functions here could make access less verbose
    with lock:
        mq.put("create:key123,8")
        response = rq.get()
    print(response)
    if response[:2] == "ok":
        name = response.split(':')[1]
        with lock:
            mq.put(f"size:{name}")
            response = rq.get()
        print(response)
        if response[:2] == "ok":
            size = int(response.split(":")[1])
            shm = shared_memory.SharedMemory(name=name, create=False, size=size)
        else:
            print("Oh no....")
            return
    else:
        print("Uh oh....")
        return
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[:] = (1,2)
    print(arr)
    shm.close()
    
def modify(mq, rq, lock):
    while True: #until the shm exists
        with lock:
            mq.put("exists:key123")
            response = rq.get()
        if response == "ok:true":
            print("key:exists")
            break
    with lock:
        mq.put("size:key123")
        response = rq.get()
    print(response)
    if response[:2] == "ok":
        size = int(response.split(":")[1])
        shm = shared_memory.SharedMemory(name="key123", create=False, size=size)
    else:
        print("Oh no....")
        return
    arr = np.ndarray((2,), buffer=shm.buf, dtype=np.int32)
    arr[0] += 5
    print(arr)
    shm.close()
    
def delete(mq, rq, lock):
    pass #TODO make a test for this?

 
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn") #because I'm mixing threads and processes
    with SHMController() as controller:
        mq, rq, lock = controller.mq, controller.rq, controller.lock
        create_task = Process(target=create, args=(mq, rq, lock))
        create_task.start()
        create_task.join()
        modify_task = Process(target=modify, args=(mq, rq, lock))
        modify_task.start()
        modify_task.join()
    print("finished")
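The comment in `create()` notes that helper functions could make access less verbose. A minimal sketch of such a helper follows; `request()` is hypothetical (not part of the answer) and wraps the lock/put/get sequence and the `ok:`/`ko:` response format used above. `toy_server()` is a hypothetical stand-in for `SHMController.process_messages` that only understands `exists`, so the sketch runs on its own with `queue.Queue` and `threading.Lock`; with the real controller you would pass `controller.mq`, `controller.rq` and `controller.lock`, which expose the same `put`/`get` and context-manager interface.

```python
import queue
import threading
from typing import Tuple


def request(mq, rq, lock, message: str) -> Tuple[bool, str]:
    """Send one "method:arg" message and return (ok, payload).

    Holding the lock across put+get means no other client can read our reply.
    """
    with lock:
        mq.put(message)
        response = rq.get()
    status, _, payload = response.partition(":")
    return status == "ok", payload


def toy_server(mq, rq):
    """Hypothetical stand-in for the controller thread (handles 'exists' only)."""
    known = {"key123"}
    while True:
        message = mq.get()
        if message is None:  # shutdown sentinel
            break
        method, arg = message.split(":", 1)
        if method == "exists":
            rq.put("ok:true" if arg in known else "ok:false")
        else:
            rq.put(f"ko:unknown method '{method}'")


mq, rq, lock = queue.Queue(), queue.Queue(), threading.Lock()
server = threading.Thread(target=toy_server, args=(mq, rq))
server.start()

r1 = request(mq, rq, lock, "exists:key123")
r2 = request(mq, rq, lock, "exists:nope")
r3 = request(mq, rq, lock, "size:key123")
print(r1)  # (True, 'true')
print(r2)  # (True, 'false')
print(r3)  # (False, "unknown method 'size'")

mq.put(None)
server.join()
```

With a helper like this, the polling loop in `modify()` reduces to `while request(mq, rq, lock, "exists:key123") != (True, "true"): pass`.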

To keep each shm alive for as long as its array, you must keep a reference to that specific shm object. Keeping the reference alongside the array is fairly straightforward: attach it as an attribute of a custom array subclass (copied from the NumPy guide to subclassing):

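import numpy as np
from multiprocessing import shared_memory

class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array

    def __new__(cls, input_array, shm=None):
        obj = np.asarray(input_array).view(cls)
        obj.shm = shm #keep the SharedMemory handle referenced for as long as the array lives
        return obj

    def __array_finalize__(self, obj):
        #also runs for views and slices, so they inherit the same shm reference
        if obj is None: return
        self.shm = getattr(obj, 'shm', None)

#example: `name` and `shape` describe an existing segment (e.g. one created through the controller above)
shm = shared_memory.SharedMemory(name=name)
np_array = SHMArray(np.ndarray(shape, buffer=shm.buf, dtype=np.int32), shm)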
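A minimal usage sketch, in a single process for brevity: `attach()` and the `"counts"` key are hypothetical names for illustration, and the plain dict stands in for a per-process dict of arrays keyed by name. It shows that a slice of an `SHMArray` inherits the same `shm` handle through `__array_finalize__`, that writes land in the shared segment, and that every array must be released before the handle it references is closed.

```python
import numpy as np
from multiprocessing import shared_memory


class SHMArray(np.ndarray):
    """ndarray subclass that carries the SharedMemory handle backing its buffer."""

    def __new__(cls, input_array, shm=None):
        obj = np.asarray(input_array).view(cls)
        obj.shm = shm
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.shm = getattr(obj, "shm", None)


def attach(name, shape, dtype=np.int32):
    """Hypothetical helper: open an existing segment and wrap it in an SHMArray."""
    shm = shared_memory.SharedMemory(name=name)
    return SHMArray(np.ndarray(shape, dtype=dtype, buffer=shm.buf), shm)


# Owner handle (the controller's role in the answer): room for 4 int32 values.
owner = shared_memory.SharedMemory(create=True, size=4 * np.dtype(np.int32).itemsize)

# Per-process dict of name -> SHMArray; each entry keeps its own handle alive.
arrays = {"counts": attach(owner.name, (4,))}
arrays["counts"][:] = 0
view = arrays["counts"][1:3]  # slicing runs __array_finalize__
view += 7

same_handle = view.shm is arrays["counts"].shm
snapshot = np.ndarray((4,), dtype=np.int32, buffer=owner.buf).tolist()
print(same_handle)  # True
print(snapshot)     # [0, 7, 7, 0]

# Release every array before closing the handle they reference.
attached = arrays["counts"].shm
del view, arrays
attached.close()
owner.close()
owner.unlink()
```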
