MCPcopy
hub / github.com/huggingface/transformers / is_torch_hpu_available

Function is_torch_hpu_available

src/transformers/utils/import_utils.py:408–500  ·  view source on GitHub ↗

Checks if `torch.hpu` is available and potentially if an HPU is in the environment

()

Source from the content-addressed store, hash-verified

406
407@lru_cache
408def is_torch_hpu_available() -> bool:
409 "Checks if `torch.hpu` is available and potentially if a HPU is in the environment"
410 if (
411 not is_torch_available()
412 or not _is_package_available("habana_frameworks")[0]
413 or not _is_package_available("habana_frameworks.torch")[0]
414 ):
415 return False
416
417 torch_hpu_min_accelerate_version = "1.5.0"
418 accelerate_available, accelerate_version = _is_package_available("accelerate", return_version=True)
419 if accelerate_available and version.parse(accelerate_version) < version.parse(torch_hpu_min_accelerate_version):
420 return False
421
422 import torch
423
424 if os.environ.get("PT_HPU_LAZY_MODE", "1") == "1":
425 # import habana_frameworks.torch in case of lazy mode to patch torch with torch.hpu
426 import habana_frameworks.torch # noqa: F401
427
428 if not hasattr(torch, "hpu") or not torch.hpu.is_available():
429 return False
430
431 # We patch torch.gather for int64 tensors to avoid a bug on Gaudi
432 # Graph compile failed with synStatus 26 [Generic failure]
433 # This can be removed once bug is fixed but for now we need it.
434 original_gather = torch.gather
435
436 def patched_gather(input: torch.Tensor, dim: int, index: torch.LongTensor) -> torch.Tensor:
437 if input.dtype == torch.int64 and input.device.type == "hpu":
438 return original_gather(input.to(torch.int32), dim, index).to(torch.int64)
439 else:
440 return original_gather(input, dim, index)
441
442 torch.gather = patched_gather
443 torch.Tensor.gather = patched_gather
444
445 original_take_along_dim = torch.take_along_dim
446
447 def patched_take_along_dim(input: torch.Tensor, indices: torch.LongTensor, dim: int | None = None) -> torch.Tensor:
448 if input.dtype == torch.int64 and input.device.type == "hpu":
449 return original_take_along_dim(input.to(torch.int32), indices, dim).to(torch.int64)
450 else:
451 return original_take_along_dim(input, indices, dim)
452
453 torch.take_along_dim = patched_take_along_dim
454
455 original_cholesky = torch.linalg.cholesky
456
457 def safe_cholesky(A, *args, **kwargs):
458 output = original_cholesky(A, *args, **kwargs)
459
460 if torch.isnan(output).any():
461 jitter_value = 1e-9
462 diag_jitter = torch.eye(A.size(-1), dtype=A.dtype, device=A.device) * jitter_value
463 output = original_cholesky(A + diag_jitter, *args, **kwargs)
464
465 return output

Callers 15

print_env.pyFile · 0.90
set_seedFunction · 0.85
__init__Method · 0.85
startMethod · 0.85
stopMethod · 0.85
require_torch_multi_hpuFunction · 0.85
testing_utils.pyFile · 0.85
__post_init__Method · 0.85
_setup_devicesMethod · 0.85
_save_rng_stateMethod · 0.85
_load_rng_stateMethod · 0.85
is_habana_gaudi1Function · 0.85

Calls 5

is_torch_availableFunction · 0.85
_is_package_availableFunction · 0.85
parseMethod · 0.45
getMethod · 0.45
is_availableMethod · 0.45

Tested by 1

require_torch_multi_hpuFunction · 0.68