Checks whether `torch.hpu` is available and, potentially, whether an HPU is present in the environment.
()
| 406 | |
| 407 | @lru_cache |
| 408 | def is_torch_hpu_available() -> bool: |
| 409 | "Checks if `torch.hpu` is available and potentially if a HPU is in the environment" |
| 410 | if ( |
| 411 | not is_torch_available() |
| 412 | or not _is_package_available("habana_frameworks")[0] |
| 413 | or not _is_package_available("habana_frameworks.torch")[0] |
| 414 | ): |
| 415 | return False |
| 416 | |
| 417 | torch_hpu_min_accelerate_version = "1.5.0" |
| 418 | accelerate_available, accelerate_version = _is_package_available("accelerate", return_version=True) |
| 419 | if accelerate_available and version.parse(accelerate_version) < version.parse(torch_hpu_min_accelerate_version): |
| 420 | return False |
| 421 | |
| 422 | import torch |
| 423 | |
| 424 | if os.environ.get("PT_HPU_LAZY_MODE", "1") == "1": |
| 425 | # import habana_frameworks.torch in case of lazy mode to patch torch with torch.hpu |
| 426 | import habana_frameworks.torch # noqa: F401 |
| 427 | |
| 428 | if not hasattr(torch, "hpu") or not torch.hpu.is_available(): |
| 429 | return False |
| 430 | |
| 431 | # We patch torch.gather for int64 tensors to avoid a bug on Gaudi |
| 432 | # Graph compile failed with synStatus 26 [Generic failure] |
| 433 | # This can be removed once bug is fixed but for now we need it. |
| 434 | original_gather = torch.gather |
| 435 | |
| 436 | def patched_gather(input: torch.Tensor, dim: int, index: torch.LongTensor) -> torch.Tensor: |
| 437 | if input.dtype == torch.int64 and input.device.type == "hpu": |
| 438 | return original_gather(input.to(torch.int32), dim, index).to(torch.int64) |
| 439 | else: |
| 440 | return original_gather(input, dim, index) |
| 441 | |
| 442 | torch.gather = patched_gather |
| 443 | torch.Tensor.gather = patched_gather |
| 444 | |
| 445 | original_take_along_dim = torch.take_along_dim |
| 446 | |
| 447 | def patched_take_along_dim(input: torch.Tensor, indices: torch.LongTensor, dim: int | None = None) -> torch.Tensor: |
| 448 | if input.dtype == torch.int64 and input.device.type == "hpu": |
| 449 | return original_take_along_dim(input.to(torch.int32), indices, dim).to(torch.int64) |
| 450 | else: |
| 451 | return original_take_along_dim(input, indices, dim) |
| 452 | |
| 453 | torch.take_along_dim = patched_take_along_dim |
| 454 | |
| 455 | original_cholesky = torch.linalg.cholesky |
| 456 | |
| 457 | def safe_cholesky(A, *args, **kwargs): |
| 458 | output = original_cholesky(A, *args, **kwargs) |
| 459 | |
| 460 | if torch.isnan(output).any(): |
| 461 | jitter_value = 1e-9 |
| 462 | diag_jitter = torch.eye(A.size(-1), dtype=A.dtype, device=A.device) * jitter_value |
| 463 | output = original_cholesky(A + diag_jitter, *args, **kwargs) |
| 464 | |
| 465 | return output |