Create a request handler class bound to a specific result instance.
(result: _OAuthResult, expected_state: str)
| 220 | |
| 221 | |
def _make_handler(result: _OAuthResult, expected_state: str) -> type[BaseHTTPRequestHandler]:
    """Build a request-handler class that reports the OAuth callback into *result*.

    The returned class closes over ``result`` and ``expected_state``, so every
    handler instance the HTTP server creates feeds the same result object.

    Args:
        result: Receives the outcome through ``set_code()`` / ``set_error()``;
            its ``ready.is_set()`` is consulted to detect an already-resolved
            flow (presumably a ``threading.Event`` — confirm in ``_OAuthResult``).
        expected_state: The ``state`` value the callback must echo back.

    Returns:
        A ``BaseHTTPRequestHandler`` subclass that serves ``GET /callback``.
    """

    class Handler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:  # noqa: N802
            """Handle the redirect back from the authorization server."""
            url = urlparse(self.path)

            # Anything other than the callback path (e.g. /favicon.ico) is a 404.
            if url.path != "/callback":
                self.send_response(404)
                self.end_headers()
                return

            # The flow has already been resolved; do not touch the result again.
            # NOTE(review): this serves the success page even if the earlier
            # resolution was an error (assuming set_error() also sets ``ready``)
            # — confirm that is intended.
            if result.ready.is_set():
                self._respond(200, _SUCCESS_HTML)
                return

            params = parse_qs(url.query)

            def first(name, fallback=None):
                """Return the first value of query parameter *name*, or *fallback*."""
                values = params.get(name)
                return fallback if values is None else values[0]

            oauth_error = first("error")
            auth_code = first("code")

            # In every branch the result is recorded before the page is rendered.
            if first("state") != expected_state:
                # A mismatched or missing state means this request did not
                # originate from our own authorization redirect (CSRF guard).
                result.set_error("Invalid state parameter — possible CSRF attack")
                status = 400
                page = _ERROR_HTML.format(error="Invalid state parameter")
            elif oauth_error:
                # Prefer the human-readable description; fall back to the code.
                detail = first("error_description", oauth_error)
                result.set_error(detail)
                status = 400
                # Escaped because the text comes straight from the query string.
                page = _ERROR_HTML.format(error=html_mod.escape(detail))
            elif auth_code:
                result.set_code(auth_code)
                status = 200
                page = _SUCCESS_HTML
            else:
                result.set_error(f"No authorization code in callback (received: {self.path})")
                status = 400
                page = _ERROR_HTML.format(error="Missing code parameter")

            self._respond(status, page)

        def _respond(self, status: int, html: str) -> None:
            """Send *html* as a UTF-8 ``text/html`` response with the given status."""
            self.send_response(status)
            self.send_header("Content-Type", "text/html; charset=utf-8")
            self.end_headers()
            body = html.encode()
            self.wfile.write(body)

        def log_message(self, format: str, *args: object) -> None:  # noqa: A002
            """Discard the base class's per-request log output."""
            return None

    return Handler
| 275 | |
| 276 | |
| 277 | class _ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer): |