(check_device: bool = False)
| 524 | |
@lru_cache
def is_torch_tpu_available(check_device: bool = False) -> bool:
    """Return whether the ``torch_tpu`` backend is installed and usable.

    Args:
        check_device: When True, additionally require that
            ``torch.tpu.device_count()`` reports at least one device. A
            ``RuntimeError`` raised while querying the backend is treated as
            "not available".

    Returns:
        True if ``torch`` and ``torch_tpu`` are installed and importable,
        ``torch`` exposes a ``tpu`` attribute whose ``is_available()`` is
        true, and (when ``check_device``) at least one device is reported.
        False otherwise.

    The result is memoized per ``check_device`` value by ``lru_cache``.
    """
    # Probe the specs before importing anything, so a missing torch or
    # torch_tpu yields False instead of raising ModuleNotFoundError.
    if (
        importlib.util.find_spec("torch") is None
        or importlib.util.find_spec("torch_tpu") is None
    ):
        return False

    import torch

    try:
        # Imported for its side effects (presumably it attaches the
        # ``torch.tpu`` namespace). It must be imported on BOTH paths.
        # Otherwise the ``hasattr(torch, "tpu")`` checks below only succeed
        # if some unrelated code imported torch_tpu first, and lru_cache
        # would freeze that order-dependent answer.
        import torch_tpu  # noqa: F401
    except ImportError:
        # Spec exists but the package cannot be imported (e.g. a broken
        # install); an availability probe reports False rather than raising.
        return False

    if not hasattr(torch, "tpu"):
        return False

    if check_device:
        try:
            if torch.tpu.is_available():
                return torch.tpu.device_count() >= 1
            return False
        except RuntimeError:
            # Querying the backend can fail when no device is reachable.
            return False

    return torch.tpu.is_available()
| 543 | |
| 544 | |
| 545 | @lru_cache |