| 24 | |
| 25 | |
| 26 | class WebSocket(HTTPConnection[StateT]): |
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """
    Wrap the ASGI ``scope`` / ``receive`` / ``send`` triple for a websocket connection.

    Both directions of the connection start in ``WebSocketState.CONNECTING``.
    ``client_state`` is advanced by ``receive()``. ``application_state`` is
    advanced by ``send()``.
    """
    super().__init__(scope)
    # NOTE(review): `assert` is stripped under `python -O`. If this guard must
    # always run, an explicit raise is needed (that would change the exception type).
    assert scope["type"] == "websocket"
    self._receive, self._send = receive, send
    # Chained assignment: client_state is bound first, then application_state.
    self.client_state = self.application_state = WebSocketState.CONNECTING
| 34 | |
async def receive(self) -> Message:
    """
    Receive the next ASGI websocket message, enforcing the client-side state machine.

    - While ``client_state`` is CONNECTING, the message must be
      ``"websocket.connect"``. The state then becomes CONNECTED.
    - While ``client_state`` is CONNECTED, the message must be
      ``"websocket.receive"`` or ``"websocket.disconnect"``. A disconnect moves
      the state to DISCONNECTED.

    Returns:
        The ASGI message exactly as produced by the underlying receive callable.

    Raises:
        RuntimeError: if the message type is not allowed in the current state,
            or if called in any other state (e.g. after a disconnect was received).
    """
    current = self.client_state
    # Guard clause: only CONNECTING and CONNECTED may read from the client.
    if current != WebSocketState.CONNECTING and current != WebSocketState.CONNECTED:
        raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')

    message = await self._receive()
    message_type = message["type"]

    # Dispatch on the state observed *before* awaiting, not on whatever
    # self.client_state holds after the await returns.
    if current == WebSocketState.CONNECTING:
        if message_type != "websocket.connect":
            raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
        self.client_state = WebSocketState.CONNECTED
        return message

    permitted = {"websocket.receive", "websocket.disconnect"}
    if message_type not in permitted:
        raise RuntimeError(
            f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
        )
    if message_type == "websocket.disconnect":
        self.client_state = WebSocketState.DISCONNECTED
    return message
| 58 | |
| 59 | async def send(self, message: Message) -> None: |
| 60 | """ |
| 61 | Send ASGI websocket messages, ensuring valid state transitions. |
| 62 | """ |
| 63 | if self.application_state == WebSocketState.CONNECTING: |
| 64 | message_type = message["type"] |
| 65 | if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}: |
| 66 | raise RuntimeError( |
| 67 | 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", ' |
| 68 | f"but got {message_type!r}" |
| 69 | ) |
| 70 | if message_type == "websocket.close": |
| 71 | self.application_state = WebSocketState.DISCONNECTED |
| 72 | elif message_type == "websocket.http.response.start": |
| 73 | self.application_state = WebSocketState.RESPONSE |
| 74 | else: |
| 75 | self.application_state = WebSocketState.CONNECTED |
| 76 | await self._send(message) |
| 77 | elif self.application_state == WebSocketState.CONNECTED: |
| 78 | message_type = message["type"] |
| 79 | if message_type not in {"websocket.send", "websocket.close"}: |
| 80 | raise RuntimeError( |
| 81 | f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}' |
| 82 | ) |
| 83 | if message_type == "websocket.close": |
no outgoing calls