Coordinates shared resource usage across nodes.
| 25 | |
| 26 | |
class _ResourceGate:
    """Counting gate that caps the number of concurrent holders of one resource.

    Unlike ``threading.Semaphore``, the capacity can be changed while permits
    are outstanding: the number of permits in use is tracked explicitly, so
    holders admitted under an earlier limit keep counting against the new one.
    """

    def __init__(self, limit: int) -> None:
        self._condition = threading.Condition()
        self._limit = limit
        self._in_use = 0  # permits currently held

    @property
    def limit(self) -> int:
        """Current maximum number of concurrent holders."""
        return self._limit

    def set_limit(self, limit: int) -> None:
        """Change the capacity in place; current holders keep their permits."""
        with self._condition:
            self._limit = limit
            # A raised limit may admit several threads that are already waiting.
            self._condition.notify_all()

    def acquire(self) -> None:
        """Block until fewer than ``limit`` permits are held, then take one."""
        with self._condition:
            self._condition.wait_for(lambda: self._in_use < self._limit)
            self._in_use += 1

    def release(self) -> None:
        """Return a permit previously taken with :meth:`acquire`.

        Raises:
            RuntimeError: if no permit is currently held.
        """
        with self._condition:
            if self._in_use <= 0:
                raise RuntimeError("release() called more times than acquire()")
            self._in_use -= 1
            # Freeing one permit can admit at most one waiter; all waiters
            # share the same predicate, so waking one is sufficient.
            self._condition.notify()


class ResourceManager:
    """Coordinates shared resource usage across nodes.

    A node type may declare a resource key and a positive limit in its
    registration capabilities. ``guard_node`` holds one permit on that key for
    the duration of the ``with`` block, so at most ``limit`` guarded nodes
    sharing the key run concurrently.
    """

    def __init__(self, log_manager: LogManager | None = None):
        self.log_manager = log_manager
        # Serialises lookup/creation of gates in ``_resources``.
        self._lock = threading.Lock()
        self._resources: Dict[str, _ResourceGate] = {}

    @contextmanager
    def guard_node(self, node: Node):
        """Acquire all resources required by the given node.

        Blocks until every required permit is available; the permits are
        released when the ``with`` block exits, including on exceptions.
        """
        requests = self._resolve_node_requests(node)
        with self._acquire_resources(requests):
            yield

    def _resolve_node_requests(self, node: Node) -> List[ResourceRequest]:
        """Build the resource requests declared by ``node``'s registration.

        Returns an empty list when the node type declares no resource key, or
        a limit that is missing, zero, or negative.
        """
        caps = get_node_registration(node.node_type).capabilities
        key = caps.resource_key
        limit = caps.resource_limit
        # ``limit`` is checked for truthiness first so that ``None`` never
        # reaches the ``> 0`` comparison.
        if key and limit and limit > 0:
            return [ResourceRequest(key=key, limit=limit)]
        return []

    @contextmanager
    def _acquire_resources(self, requests: Iterable[ResourceRequest]):
        """Hold one permit on every requested resource while the body runs.

        Permits are taken in sorted-key order, a single global ordering that
        prevents two callers with overlapping requests from deadlocking on
        each other, and released in reverse order. If acquisition fails part-way,
        permits already taken are released before the error propagates.

        NOTE: a key that appears twice in ``requests`` is acquired twice.
        """
        acquired: List[Tuple[str, _ResourceGate]] = []
        try:
            for request in sorted(requests, key=lambda item: item.key):
                gate = self._get_or_create_resource(request)
                self._log_debug(f"Acquiring resource {request.key}")
                gate.acquire()
                acquired.append((request.key, gate))
            yield
        finally:
            for key, gate in reversed(acquired):
                gate.release()
                self._log_debug(f"Released resource {key}")

    def _get_or_create_resource(self, request: ResourceRequest) -> _ResourceGate:
        """Return the shared gate for ``request.key``, creating it on first use.

        When the stored gate has a different limit, its capacity is updated in
        place. Replacing the gate instead would orphan permits already held on
        the old one, letting the total number of holders exceed the limit (and
        callers alternating between two limits would never be capped at all).
        """
        with self._lock:
            gate = self._resources.get(request.key)
            if gate is None:
                gate = _ResourceGate(request.limit)
                self._resources[request.key] = gate
            elif gate.limit != request.limit:
                self._log_debug(
                    f"Changing limit of resource {request.key} "
                    f"from {gate.limit} to {request.limit}"
                )
                gate.set_limit(request.limit)
            return gate

    def _log_debug(self, message: str) -> None:
        """Forward ``message`` to the log manager's ``debug`` method, if one is set."""
        if self.log_manager is not None:
            self.log_manager.debug(message)