()
| 1631 | |
| 1632 | |
def is_sagemaker_mp_enabled() -> bool:
    """Return whether SageMaker model parallelism can be enabled for this process.

    All of the following must hold:

    1. The ``SM_HP_MP_PARAMETERS`` environment variable holds a JSON object that
       contains a ``"partitions"`` key (required for model parallel).
    2. The ``SM_FRAMEWORK_PARAMS`` environment variable holds a JSON object whose
       ``"sagemaker_mpi_enabled"`` entry is truthy.
    3. The ``smdistributed`` package is reported as available.

    Unset variables default to ``"{}"``. Malformed JSON, or JSON that is not an
    object, makes this return ``False`` instead of raising.

    Returns:
        ``True`` only if all three conditions above are met.
    """
    # Get the SageMaker-specific model-parallel parameters.
    try:
        smp_options = json.loads(os.getenv("SM_HP_MP_PARAMETERS", "{}"))
    except json.JSONDecodeError:
        return False
    # Require a JSON object: for a JSON string, `in` would be a substring test,
    # and for numbers/null it would raise TypeError.
    if not isinstance(smp_options, dict) or "partitions" not in smp_options:
        return False

    # Get the SageMaker-specific framework parameters and check the field
    # "sagemaker_mpi_enabled".
    try:
        mpi_options = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    except json.JSONDecodeError:
        return False
    # Non-object JSON (e.g. a list) has no `.get`; treat it as "not enabled".
    if not isinstance(mpi_options, dict) or not mpi_options.get("sagemaker_mpi_enabled", False):
        return False

    # Lastly, check if the `smdistributed` module is present. The code uses the
    # first element of `_is_package_available`'s result as the availability flag.
    return _is_package_available("smdistributed")[0]
| 1655 | |
| 1656 | |
| 1657 | def is_training_run_on_sagemaker() -> bool: |
no test coverage detected