Checks whether `mlu` is available via a `cndev`-based check, which does not trigger the drivers and leaves the MLU uninitialized.
()
| 333 | |
@lru_cache
def is_torch_mlu_available() -> bool:
    """
    Checks whether `mlu` is available via a `cndev`-based check, which is intended not to trigger the drivers
    and to leave the MLU uninitialized.

    `PYTORCH_CNDEV_BASED_MLU_CHECK` is set to `"1"` only for the duration of the `torch.mlu.is_available()`
    call. Afterwards the variable is restored to its previous value (including an empty string), or removed if
    it was not set before. The result is memoized with `lru_cache`, so later calls return the cached value.

    Returns:
        `bool`: `True` if torch and `torch_mlu` are installed, `torch` exposes an `mlu` attribute, and
        `torch.mlu.is_available()` returns a truthy value; `False` otherwise.
    """
    if not is_torch_available() or not _is_package_available("torch_mlu")[0]:
        return False

    import torch
    import torch_mlu  # noqa: F401  -- imported for side effects only (presumably registers `torch.mlu`)

    env_var = "PYTORCH_CNDEV_BASED_MLU_CHECK"
    previous_value = os.environ.get(env_var)
    try:
        os.environ[env_var] = "1"
        available = torch.mlu.is_available() if hasattr(torch, "mlu") else False
    finally:
        # Compare against None (not truthiness) so an explicitly empty previous value is restored, not dropped.
        if previous_value is not None:
            os.environ[env_var] = previous_value
        else:
            os.environ.pop(env_var, None)

    return available
| 357 | |
| 358 | |
| 359 | @lru_cache |
no test coverage detected