| 16 | |
| 17 | |
class ExceptionMiddleware:
    """ASGI middleware that keeps exception-handler registries for an app.

    Two registries are maintained: one keyed by integer status code and one
    keyed by exception class. For ``http`` and ``websocket`` scopes both are
    stored on the scope under ``"starlette.exception_handlers"`` and the
    wrapped app is run through ``wrap_app_handling_exceptions``.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: Mapping[Any, ExceptionHandler] | None = None,
        debug: bool = False,
    ) -> None:
        """Store the wrapped app and register default plus user handlers.

        Args:
            app: The ASGI application this middleware delegates to.
            handlers: Optional mapping of status code or exception class to
                handler; each entry goes through ``add_exception_handler``.
            debug: Stored on the instance; not otherwise used in this class.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        # Defaults are installed first, so an entry in `handlers` keyed by the
        # same exception class replaces the built-in handler.
        builtin_handlers: ExceptionHandlers = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        self._status_handlers: StatusHandlers = {}
        self._exception_handlers: ExceptionHandlers = builtin_handlers
        if handlers is not None:  # pragma: no branch
            for target, callback in handlers.items():
                self.add_exception_handler(target, callback)

    def add_exception_handler(
        self,
        exc_class_or_status_code: int | type[Exception],
        handler: ExceptionHandler,
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        Integer keys go into the status-code registry; anything else must be
        an ``Exception`` subclass and goes into the exception registry.
        """
        if not isinstance(exc_class_or_status_code, int):
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler
            return
        self._status_handlers[exc_class_or_status_code] = handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, with exception handling for http/websocket scopes."""
        scope_type = scope["type"]
        if scope_type != "http" and scope_type != "websocket":
            # Other scope types are passed straight through, untouched.
            await self.app(scope, receive, send)
            return

        # Publish the registries on the scope before the connection object is
        # built, matching the order the wrapped call observes them in.
        registries = (self._exception_handlers, self._status_handlers)
        scope["starlette.exception_handlers"] = registries

        conn: Request | WebSocket = (
            Request(scope, receive, send) if scope_type == "http" else WebSocket(scope, receive, send)
        )

        guarded_app = wrap_app_handling_exceptions(self.app, conn)
        await guarded_app(scope, receive, send)

    async def http_exception(self, request: Request, exc: Exception) -> Response:
        """Build the default response for an ``HTTPException``.

        Status codes 204 and 304 produce a ``Response`` without the detail
        text; every other status produces a ``PlainTextResponse`` carrying
        ``exc.detail``. The exception's headers are forwarded in both cases.
        """
        assert isinstance(exc, HTTPException)
        status = exc.status_code
        if status in {204, 304}:
            return Response(status_code=status, headers=exc.headers)
        return PlainTextResponse(exc.detail, status_code=status, headers=exc.headers)

    async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
        """Close the websocket using the code and reason from a ``WebSocketException``."""
        assert isinstance(exc, WebSocketException)
        await websocket.close(code=exc.code, reason=exc.reason)  # pragma: no cover
no outgoing calls