If optimizer and scheduler states exist, load them.
(self, checkpoint: str | None)
| 3585 | set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed) |
| 3586 | |
| 3587 | def _load_optimizer_and_scheduler(self, checkpoint: str | None) -> None: |
| 3588 | """If optimizer and scheduler states exist, load them.""" |
| 3589 | if checkpoint is None: |
| 3590 | return |
| 3591 | |
| 3592 | if self.is_deepspeed_enabled: |
| 3593 | # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init |
| 3594 | if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): |
| 3595 | with warnings.catch_warnings(record=True) as caught_warnings: |
| 3596 | check_torch_load_is_safe() |
| 3597 | self.lr_scheduler.load_state_dict( |
| 3598 | torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True) |
| 3599 | ) |
| 3600 | reissue_pt_warnings(caught_warnings) |
| 3601 | return |
| 3602 | |
| 3603 | checkpoint_file_exists = ( |
| 3604 | glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") |
| 3605 | if is_sagemaker_mp_enabled() |
| 3606 | else ( |
| 3607 | os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) |
| 3608 | or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN)) |
| 3609 | or ( |
| 3610 | os.path.isdir(checkpoint) |
| 3611 | and any( |
| 3612 | OPTIMIZER_NAME_BIN.split(".")[0] in folder_name |
| 3613 | for folder_name in os.listdir(checkpoint) |
| 3614 | if os.path.isdir(os.path.join(checkpoint, folder_name)) |
| 3615 | ) |
| 3616 | ) |
| 3617 | ) |
| 3618 | ) |
| 3619 | checkpoint_file_exists = ( |
| 3620 | glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}")) |
| 3621 | if self.is_fsdp_xla_v1_enabled |
| 3622 | else checkpoint_file_exists |
| 3623 | ) |
| 3624 | if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): |
| 3625 | # Load in optimizer and scheduler states |
| 3626 | if is_torch_xla_available(): |
| 3627 | # On TPU we have to take some extra precautions to properly load the states on the right device. |
| 3628 | if self.is_fsdp_xla_v1_enabled: |
| 3629 | check_torch_load_is_safe() |
| 3630 | optimizer_state = torch.load( |
| 3631 | os.path.join( |
| 3632 | checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}" |
| 3633 | ), |
| 3634 | map_location="cpu", |
| 3635 | weights_only=True, |
| 3636 | ) |
| 3637 | # We only need `optimizer` when resuming from checkpoint |
| 3638 | optimizer_state = optimizer_state["optimizer"] |
| 3639 | else: |
| 3640 | check_torch_load_is_safe() |
| 3641 | optimizer_state = torch.load( |
| 3642 | os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True |
| 3643 | ) |
| 3644 | with warnings.catch_warnings(record=True) as caught_warnings: |
no test coverage detected