Checks whether XPU acceleration is available via stock PyTorch (>=2.6) and, optionally (when `check_device` is True), whether an XPU device is present in the environment.
(check_device: bool = False)
| 308 | |
| 309 | @lru_cache |
| 310 | def is_torch_xpu_available(check_device: bool = False) -> bool: |
| 311 | """ |
| 312 | Checks if XPU acceleration is available via stock PyTorch (>=2.6) and |
| 313 | potentially if a XPU is in the environment. |
| 314 | """ |
| 315 | if not is_torch_available(): |
| 316 | return False |
| 317 | |
| 318 | torch_version = version.parse(get_torch_version()) |
| 319 | if torch_version.major == 2 and torch_version.minor < 6: |
| 320 | return False |
| 321 | |
| 322 | import torch |
| 323 | |
| 324 | if check_device: |
| 325 | try: |
| 326 | # Will raise a RuntimeError if no XPU is found |
| 327 | _ = torch.xpu.device_count() |
| 328 | return torch.xpu.is_available() |
| 329 | except RuntimeError: |
| 330 | return False |
| 331 | return hasattr(torch, "xpu") and torch.xpu.is_available() |
| 332 | |
| 333 | |
| 334 | @lru_cache |