Check whether `torch_xla` is available. To train a native PyTorch job in an environment where `torch_xla` is installed, set `USE_TORCH_XLA` to false.
`is_torch_xla_available(check_is_tpu=False, check_is_gpu=False) -> bool`
| 384 | |
@lru_cache
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False) -> bool:
    """
    Check if `torch_xla` is available. To train a native PyTorch job in an environment with torch_xla installed, set
    `USE_TORCH_XLA` to false.

    Args:
        check_is_tpu (`bool`, *optional*, defaults to `False`):
            If `True`, additionally require the torch_xla runtime to report a `"TPU"` device type.
        check_is_gpu (`bool`, *optional*, defaults to `False`):
            If `True`, additionally require the torch_xla runtime to report a `"GPU"` or `"CUDA"` device type.

    Returns:
        `bool`: `False` if XLA use is disabled via `USE_TORCH_XLA` or the `torch_xla` package is not found. Otherwise
        the result of the requested device-type check, or `True` when no device check was requested.

    Raises:
        ValueError: If both `check_is_tpu` and `check_is_gpu` are `True`.
    """
    # Raise rather than `assert`: assertions are stripped under `python -O`, which would let the
    # contradictory request silently fall through to the GPU branch below.
    if check_is_tpu and check_is_gpu:
        raise ValueError("The check_is_tpu and check_is_gpu cannot both be true.")

    # NOTE: results are memoized by `lru_cache`, so `USE_TORCH_XLA` / package presence are only
    # evaluated on the first call for each argument combination.
    # NOTE(review): assumes the first element of `_is_package_available(...)` is the "found" flag — confirm.
    torch_xla_available = USE_TORCH_XLA in ENV_VARS_TRUE_VALUES and _is_package_available("torch_xla")[0]
    if not torch_xla_available:
        return False

    # Import the runtime submodule explicitly (this also imports the `torch_xla` package itself)
    # instead of relying on `import torch_xla` to expose the `.runtime` attribute.
    import torch_xla.runtime as xr

    if check_is_gpu:
        return xr.device_type() in ("GPU", "CUDA")
    if check_is_tpu:
        return xr.device_type() == "TPU"

    return True
| 405 | |
| 406 | |
| 407 | @lru_cache |