| 395 | |
| 396 | |
def lazy_load_kernel(
    kernel_name: str, mapping: dict[str, ModuleType | None] = _KERNEL_MODULE_MAPPING
) -> ModuleType | None:
    """Return the module for ``kernel_name``, loading it on first use and caching it in ``mapping``.

    Resolution order:

    1. If ``mapping`` already holds a ``ModuleType`` for ``kernel_name``, return it. A cached
       ``None`` is *not* treated as a hit, so a previously failed load is retried on later calls.
    2. If ``kernel_name`` has no entry in ``_HUB_KERNEL_MAPPING``, warn once, cache ``None`` and
       return ``None``.
    3. If ``_kernels_available`` is truthy, load the kernel through ``get_kernel`` using the
       entry's ``"repo_id"`` plus its optional ``"revision"`` / ``"version"``.
    4. Otherwise, fall back to a locally installed module. Let ``<name>`` be ``kernel_name`` with
       ``-`` replaced by ``_``. If ``..utils.import_utils`` (resolved relative to ``__package__``)
       defines a callable ``is_<name>_available`` and it returns a truthy value, import the
       top-level module ``<name>``.

    Args:
        kernel_name: Key into ``_HUB_KERNEL_MAPPING``; may contain ``-``.
        mapping: Cache that is updated in place with the result (a module or ``None``). It
            defaults to the module-level ``_KERNEL_MODULE_MAPPING``, so results are shared
            across calls.

    Returns:
        The loaded module, or ``None`` if the kernel could not be loaded.

    Raises:
        Exceptions from ``get_kernel`` other than ``FileNotFoundError`` / ``AssertionError``
        propagate. So does ``KeyError`` if the hub entry lacks ``"repo_id"``, and any exception
        raised by the ``is_<name>_available`` check itself.
    """
    cached = mapping.get(kernel_name)
    if isinstance(cached, ModuleType):
        return cached

    if kernel_name not in _HUB_KERNEL_MAPPING:
        logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
        mapping[kernel_name] = None
        return None

    if _kernels_available:
        hub_spec = _HUB_KERNEL_MAPPING[kernel_name]
        try:
            kernel = get_kernel(
                hub_spec["repo_id"],
                revision=hub_spec.get("revision", None),
                version=hub_spec.get("version", None),
                allow_all_kernels=ALLOW_ALL_KERNELS,
            )
        except FileNotFoundError as e:
            mapping[kernel_name] = None
            logger.warning_once(f"Failed to load kernel {kernel_name}: {e}")
        except AssertionError:
            # Happens when torch is built without an accelerator backend; fall back to slow path.
            mapping[kernel_name] = None
        else:
            mapping[kernel_name] = kernel
        return mapping[kernel_name]

    # Local fallback: look for an ``is_<name>_available`` helper in ``..utils.import_utils``.
    import importlib

    module_name = kernel_name.replace("-", "_")
    func_name = f"is_{module_name}_available"

    try:
        utils_mod = importlib.import_module("..utils.import_utils", __package__)
        is_kernel_available = getattr(utils_mod, func_name, None)
    except Exception:
        # Any failure to reach the helper module (e.g. no parent package) means "not available".
        is_kernel_available = None

    if not (callable(is_kernel_available) and is_kernel_available()):
        mapping[kernel_name] = None
        return None

    # Absolute import of ``module_name`` as a top-level module, not a submodule of this package.
    try:
        module = importlib.import_module(module_name)
    except Exception as e:
        # The availability check passed but the import failed. Keep the best-effort ``None``
        # fallback, but surface the reason instead of swallowing it (mirrors the hub branch).
        logger.warning_once(f"Failed to import kernel module {module_name} for {kernel_name}: {e}")
        module = None

    mapping[kernel_name] = module
    return module
| 443 | |
| 444 | |
| 445 | def get_kernel( |