|
| 1 | +""" |
| 2 | +A context manager for parallel and distributed processing using |
| 3 | +multiprocessing.Manager to share state across processes. |
| 4 | +""" |
| 5 | + |
| 6 | +from collections.abc import Callable |
| 7 | +from collections.abc import Iterator |
| 8 | +import multiprocessing as mp |
| 9 | +from multiprocessing.managers import DictProxy |
| 10 | +import threading |
| 11 | +from threading import Lock |
| 12 | +from typing import Any |
| 13 | +from typing import TypeVar |
| 14 | + |
| 15 | +from laygo.context.types import IContextHandle |
| 16 | +from laygo.context.types import IContextManager |
| 17 | + |
| 18 | +R = TypeVar("R") |
| 19 | + |
| 20 | + |
| 21 | +class ParallelContextHandle(IContextHandle): |
| 22 | + """ |
| 23 | + A lightweight, picklable handle that carries the actual shared objects |
| 24 | + (the DictProxy and Lock) to worker processes. |
| 25 | + """ |
| 26 | + |
| 27 | + def __init__(self, shared_dict: DictProxy, lock: Lock): |
| 28 | + self._shared_dict = shared_dict |
| 29 | + self._lock = lock |
| 30 | + |
| 31 | + def create_proxy(self) -> "IContextManager": |
| 32 | + """ |
| 33 | + Creates a new ParallelContextManager instance that wraps the shared |
| 34 | + objects received by the worker process. |
| 35 | + """ |
| 36 | + return ParallelContextManager(handle=self) |
| 37 | + |
| 38 | + |
| 39 | +class ParallelContextManager(IContextManager): |
| 40 | + """ |
| 41 | + A context manager that enables state sharing across processes. |
| 42 | +
|
| 43 | + It operates in two modes: |
| 44 | + 1. Main Mode: When created normally, it starts a multiprocessing.Manager |
| 45 | + and creates a shared dictionary and lock. |
| 46 | + 2. Proxy Mode: When created from a handle, it wraps a DictProxy and Lock |
| 47 | + that were received from another process. It does not own the manager. |
| 48 | + """ |
| 49 | + |
| 50 | + def __init__(self, initial_context: dict[str, Any] | None = None, handle: ParallelContextHandle | None = None): |
| 51 | + """ |
| 52 | + Initializes the manager. If a handle is provided, it initializes in |
| 53 | + proxy mode; otherwise, it starts a new manager. |
| 54 | + """ |
| 55 | + if handle: |
| 56 | + # --- PROXY MODE INITIALIZATION --- |
| 57 | + # This instance is a client wrapping objects from an existing server. |
| 58 | + self._manager = None # Proxies do not own the manager process. |
| 59 | + self._shared_dict = handle._shared_dict |
| 60 | + self._lock = handle._lock |
| 61 | + else: |
| 62 | + # --- MAIN MODE INITIALIZATION --- |
| 63 | + # This instance owns the manager and its shared objects. |
| 64 | + self._manager = mp.Manager() |
| 65 | + self._shared_dict = self._manager.dict(initial_context or {}) |
| 66 | + self._lock = self._manager.Lock() |
| 67 | + |
| 68 | + # Thread-local storage for lock state to handle concurrent access |
| 69 | + self._local = threading.local() |
| 70 | + |
| 71 | + def _lock_context(self) -> None: |
| 72 | + """Acquire the lock for this context manager.""" |
| 73 | + if not getattr(self._local, "is_locked", False): |
| 74 | + self._lock.acquire() |
| 75 | + self._local.is_locked = True |
| 76 | + |
| 77 | + def _unlock_context(self) -> None: |
| 78 | + """Release the lock for this context manager.""" |
| 79 | + if getattr(self._local, "is_locked", False): |
| 80 | + self._lock.release() |
| 81 | + self._local.is_locked = False |
| 82 | + |
| 83 | + def _execute_locked(self, operation: Callable[[], R]) -> R: |
| 84 | + """A private helper to execute an operation within a lock.""" |
| 85 | + if not getattr(self._local, "is_locked", False): |
| 86 | + self._lock_context() |
| 87 | + try: |
| 88 | + return operation() |
| 89 | + finally: |
| 90 | + self._unlock_context() |
| 91 | + else: |
| 92 | + return operation() |
| 93 | + |
| 94 | + def get_handle(self) -> ParallelContextHandle: |
| 95 | + """ |
| 96 | + Returns a picklable handle containing the shared dict and lock. |
| 97 | + Only the main instance can generate handles. |
| 98 | + """ |
| 99 | + if not self._manager: |
| 100 | + raise TypeError("Cannot get a handle from a proxy context instance.") |
| 101 | + |
| 102 | + return ParallelContextHandle(self._shared_dict, self._lock) |
| 103 | + |
| 104 | + def shutdown(self) -> None: |
| 105 | + """ |
| 106 | + Shuts down the background manager process. |
| 107 | + This is a no-op for proxy instances. |
| 108 | + """ |
| 109 | + if self._manager: |
| 110 | + self._manager.shutdown() |
| 111 | + |
| 112 | + def __enter__(self) -> "ParallelContextManager": |
| 113 | + """Acquires the lock for use in a 'with' statement.""" |
| 114 | + self._lock_context() |
| 115 | + return self |
| 116 | + |
| 117 | + def __exit__(self, exc_type, exc_val, exc_tb) -> None: |
| 118 | + """Releases the lock.""" |
| 119 | + self._unlock_context() |
| 120 | + |
| 121 | + def __getitem__(self, key: str) -> Any: |
| 122 | + return self._shared_dict[key] |
| 123 | + |
| 124 | + def __setitem__(self, key: str, value: Any) -> None: |
| 125 | + self._execute_locked(lambda: self._shared_dict.__setitem__(key, value)) |
| 126 | + |
| 127 | + def __delitem__(self, key: str) -> None: |
| 128 | + self._execute_locked(lambda: self._shared_dict.__delitem__(key)) |
| 129 | + |
| 130 | + def __iter__(self) -> Iterator[str]: |
| 131 | + # Iteration needs to copy the keys to be safe across processes |
| 132 | + return self._execute_locked(lambda: iter(list(self._shared_dict.keys()))) |
| 133 | + |
| 134 | + def __len__(self) -> int: |
| 135 | + return self._execute_locked(lambda: len(self._shared_dict)) |
| 136 | + |
| 137 | + def to_dict(self) -> dict[str, Any]: |
| 138 | + return self._execute_locked(lambda: dict(self._shared_dict)) |
0 commit comments