| 100 | |
| 101 | |
| 102 | class WebSocketTestSession: |
| 103 | def __init__( |
| 104 | self, |
| 105 | app: ASGI3App, |
| 106 | scope: Scope, |
| 107 | portal_factory: _PortalFactoryType, |
| 108 | ) -> None: |
| 109 | self.app = app |
| 110 | self.scope = scope |
| 111 | self.accepted_subprotocol = None |
| 112 | self.portal_factory = portal_factory |
| 113 | self.extra_headers = None |
| 114 | |
| 115 | def __enter__(self) -> Self: |
| 116 | with contextlib.ExitStack() as stack: |
| 117 | self.portal = portal = stack.enter_context(self.portal_factory()) |
| 118 | fut, cs = portal.start_task(self._run) |
| 119 | stack.callback(fut.result) |
| 120 | stack.callback(portal.call, cs.cancel) |
| 121 | self.send({class="st">"type": class="st">"websocket.connect"}) |
| 122 | message = self.receive() |
| 123 | self._raise_on_close(message) |
| 124 | self.accepted_subprotocol = message.get(class="st">"subprotocol", None) |
| 125 | self.extra_headers = message.get(class="st">"headers", None) |
| 126 | stack.callback(self.close, 1000) |
| 127 | self.exit_stack = stack.pop_all() |
| 128 | return self |
| 129 | |
| 130 | def __exit__(self, *args: Any) -> bool | None: |
| 131 | return self.exit_stack.__exit__(*args) |
| 132 | |
| 133 | async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: |
| 134 | class="st">""" |
| 135 | The sub-thread in which the websocket session runs. |
| 136 | class="st">""" |
| 137 | send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) |
| 138 | send_tx, send_rx = send |
| 139 | receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) |
| 140 | receive_tx, receive_rx = receive |
| 141 | with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: |
| 142 | self._receive_tx = receive_tx |
| 143 | self._send_rx = send_rx |
| 144 | task_status.started(cs) |
| 145 | await self.app(self.scope, receive_rx.receive, send_tx.send) |
| 146 | |
| 147 | class="cm"># wait for cs.cancel to be called before closing streams |
| 148 | await anyio.sleep_forever() |
| 149 | |
| 150 | def _raise_on_close(self, message: Message) -> None: |
| 151 | if message[class="st">"type"] == class="st">"websocket.close": |
| 152 | raise WebSocketDisconnect(code=message.get(class="st">"code", 1000), reason=message.get(class="st">"reason", class="st">"")) |
| 153 | elif message[class="st">"type"] == class="st">"websocket.http.response.start": |
| 154 | status_code: int = message[class="st">"status"] |
| 155 | headers: list[tuple[bytes, bytes]] = message[class="st">"headers"] |
| 156 | body: list[bytes] = [] |
| 157 | while True: |
| 158 | message = self.receive() |
| 159 | assert message[class="st">"type"] == class="st">"websocket.http.response.body" |