| 281 | |
| 282 | |
| 283 | class SSEDecoder: |
| 284 | _data: list[str] |
| 285 | _event: str | None |
| 286 | _retry: int | None |
| 287 | _last_event_id: str | None |
| 288 | |
| 289 | def __init__(self) -> None: |
| 290 | self._event = None |
| 291 | self._data = [] |
| 292 | self._last_event_id = None |
| 293 | self._retry = None |
| 294 | |
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
    """Consume an iterator of raw bytes and yield each event that ``decode`` produces.

    The input is first regrouped by ``_iter_chunks``. Every line of each group
    is decoded as UTF-8 and passed to ``decode``; truthy results are yielded
    in the order they are produced.
    """
    for block in self._iter_chunks(iterator):
        # Lines are separated while still bytes: bytes.splitlines() breaks only
        # on \r and \n, whereas str.splitlines() recognises further separators.
        for encoded in block.splitlines():
            event = self.decode(encoded.decode("utf-8"))
            if event:
                yield event
| 304 | |
| 305 | def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]: |
| 306 | """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" |
| 307 | data = b"" |
| 308 | for chunk in iterator: |
| 309 | for line in chunk.splitlines(keepends=True): |
| 310 | data += line |
| 311 | if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): |
| 312 | yield data |
| 313 | data = b"" |
| 314 | if data: |
| 315 | yield data |
| 316 | |
async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
    """Consume an async iterator of raw bytes and yield each event that ``decode`` produces.

    The input is first regrouped by ``_aiter_chunks``. Every line of each group
    is decoded as UTF-8 and passed to ``decode``; truthy results are yielded
    in the order they are produced.
    """
    async for block in self._aiter_chunks(iterator):
        # Lines are separated while still bytes: bytes.splitlines() breaks only
        # on \r and \n, whereas str.splitlines() recognises further separators.
        for encoded in block.splitlines():
            event = self.decode(encoded.decode("utf-8"))
            if event:
                yield event
| 326 | |
| 327 | async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]: |
| 328 | """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" |
| 329 | data = b"" |
| 330 | async for chunk in iterator: |
| 331 | for line in chunk.splitlines(keepends=True): |
| 332 | data += line |
| 333 | if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): |
| 334 | yield data |
| 335 | data = b"" |
| 336 | if data: |
| 337 | yield data |
| 338 | |
| 339 | def decode(self, line: str) -> ServerSentEvent | None: |
| 340 | # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 |