| 220 | |
| 221 | |
| 222 | class StreamingResponse(Response): |
| 223 | body_iterator: AsyncContentStream |
| 224 | |
def __init__(
    self,
    content: ContentStream,
    status_code: int = 200,
    headers: Mapping[str, str] | None = None,
    media_type: str | None = None,
    background: BackgroundTask | None = None,
) -> None:
    """Prepare a response whose body is produced incrementally from *content*.

    Args:
        content: The body source. An ``AsyncIterable`` is stored as-is; any
            other iterable is wrapped with ``iterate_in_threadpool`` so the
            body can always be consumed with ``async for``.
        status_code: HTTP status code reported in the start message.
        headers: Optional extra response headers, handed to ``init_headers``.
        media_type: Overrides the class-level ``media_type`` when not ``None``.
        background: Optional task; ``__call__`` awaits it after streaming.
    """
    already_async = isinstance(content, AsyncIterable)
    self.body_iterator = content if already_async else iterate_in_threadpool(content)
    self.status_code = status_code
    # Always materialise media_type on the instance, falling back to the class default.
    self.media_type = media_type if media_type is not None else self.media_type
    self.background = background
    # Must run last: header construction reads the attributes assigned above.
    self.init_headers(headers)
| 241 | |
async def listen_for_disconnect(self, receive: Receive) -> None:
    """Consume incoming ASGI messages until the client disconnects.

    Every message whose ``type`` is not ``http.disconnect`` is read and
    discarded; the coroutine returns as soon as a disconnect is observed.
    """
    disconnected = False
    while not disconnected:
        incoming = await receive()
        disconnected = incoming["type"] == "http.disconnect"
| 247 | |
async def stream_response(self, send: Send) -> None:
    """Send the start message, each body chunk, then a terminating empty chunk.

    ``bytes`` and ``memoryview`` chunks are forwarded untouched; any other
    chunk (presumably ``str``) is encoded with ``self.charset``. Every chunk is
    sent with ``more_body=True``; the stream ends with an empty body and
    ``more_body=False``.
    """
    start_message = {"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}
    await send(start_message)

    async for piece in self.body_iterator:
        payload = piece if isinstance(piece, (bytes, memoryview)) else piece.encode(self.charset)
        await send({"type": "http.response.body", "body": payload, "more_body": True})

    await send({"type": "http.response.body", "body": b"", "more_body": False})
| 256 | |
| 257 | async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 258 | if scope[class="st">"type"] == class="st">"websocket": |
| 259 | send = self._wrap_websocket_denial_send(send) |
| 260 | await self.stream_response(send) |
| 261 | if self.background is not None: |
| 262 | await self.background() |
| 263 | return |
| 264 | |
| 265 | spec_version = tuple(map(int, scope.get(class="st">"asgi", {}).get(class="st">"spec_version", class="st">"2.0").split(class="st">"."))) |
| 266 | |
| 267 | if spec_version >= (2, 4): |
| 268 | try: |
| 269 | await self.stream_response(send) |
| 270 | except OSError: |
| 271 | raise ClientDisconnect() |
| 272 | else: |
| 273 | async with create_collapsing_task_group() as task_group: |
| 274 | |
| 275 | async def wrap(func: Callable[[], Awaitable[None]]) -> None: |
| 276 | await func() |
| 277 | task_group.cancel_scope.cancel() |
| 278 | |
| 279 | task_group.start_soon(wrap, partial(self.stream_response, send)) |
no outgoing calls