Checks whether `torch_npu` is installed and, potentially, whether an NPU is present in the environment
(check_device=False)
| 287 | |
@lru_cache
def is_torch_npu_available(check_device=False) -> bool:
    """Report whether `torch_npu` is importable and `torch.npu` says an NPU is available.

    Args:
        check_device: When True, first call ``torch.npu.device_count()`` and
            treat a ``RuntimeError`` raised while probing ``torch.npu`` as
            "not available".

    Returns:
        False when torch is unavailable, when the ``torch_npu`` package is not
        found, or when ``torch`` has no ``npu`` attribute after importing
        ``torch_npu``; otherwise the value of ``torch.npu.is_available()``
        (or False if ``check_device`` is set and probing raised RuntimeError).

    Results are memoized per distinct argument combination by ``lru_cache``.
    """
    # Short-circuit: the package lookup only runs if torch itself is available.
    torch_missing = not is_torch_available()
    if torch_missing or not _is_package_available("torch_npu")[0]:
        return False

    import torch
    # Not referenced directly below; presumably imported so that `torch.npu`
    # becomes available on the torch module -- confirm against torch_npu docs.
    import torch_npu  # noqa: F401

    if not check_device:
        return hasattr(torch, "npu") and torch.npu.is_available()

    try:
        if not hasattr(torch, "npu"):
            return False
        # Per the original author's note, device_count() raises RuntimeError
        # when no NPU is found.
        _ = torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
| 307 | |
| 308 | |
| 309 | @lru_cache |