Checks whether `torch_musa` is installed and, optionally, whether a MUSA device is present in the environment (`check_device=False` by default).
| 358 | |
@lru_cache
def is_torch_musa_available(check_device=False) -> bool:
    """Report whether the MUSA backend of PyTorch appears usable.

    Returns ``False`` early when torch is unavailable, when the ``torch_musa``
    package is not installed, or when an installed ``accelerate`` is older than
    the minimum version checked below.

    With ``check_device=False`` the answer is ``torch.musa.is_available()``, or
    ``False`` if ``torch`` exposes no ``musa`` attribute after ``torch_musa`` is
    imported. With ``check_device=True`` the device count is probed first, and a
    ``RuntimeError`` raised during the probe yields ``False``.

    Results are memoized by ``lru_cache``, keyed on ``check_device``.
    """
    prerequisites_met = is_torch_available() and _is_package_available("torch_musa")[0]
    if not prerequisites_met:
        return False

    import torch
    import torch_musa  # noqa: F401  -- imported for side effects; presumably provides `torch.musa` (confirm)

    # NOTE: this threshold is compared against the installed *accelerate*
    # version, not torch_musa's own version.
    min_accelerate_version = "0.33.0"
    has_accelerate, installed_accelerate_version = _is_package_available("accelerate", return_version=True)
    if has_accelerate and version.parse(installed_accelerate_version) < version.parse(min_accelerate_version):
        return False

    if not check_device:
        return hasattr(torch, "musa") and torch.musa.is_available()

    try:
        if not hasattr(torch, "musa"):
            return False
        # Will raise a RuntimeError if no MUSA is found
        torch.musa.device_count()
        return torch.musa.is_available()
    except RuntimeError:
        return False
| 383 | |
| 384 | |
| 385 | @lru_cache |
no test coverage detected