MCPcopy
hub / github.com/huggingface/transformers / _load_optimizer_and_scheduler

Method _load_optimizer_and_scheduler

src/transformers/trainer.py:3587–3689  ·  view source on GitHub ↗

If optimizer and scheduler states exist, load them.

(self, checkpoint: str | None)

Source from the content-addressed store, hash-verified

3585 set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)
3586
3587 def _load_optimizer_and_scheduler(self, checkpoint: str | None) -> None:
3588 """If optimizer and scheduler states exist, load them."""
3589 if checkpoint is None:
3590 return
3591
3592 if self.is_deepspeed_enabled:
3593 # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
3594 if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
3595 with warnings.catch_warnings(record=True) as caught_warnings:
3596 check_torch_load_is_safe()
3597 self.lr_scheduler.load_state_dict(
3598 torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
3599 )
3600 reissue_pt_warnings(caught_warnings)
3601 return
3602
3603 checkpoint_file_exists = (
3604 glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
3605 if is_sagemaker_mp_enabled()
3606 else (
3607 os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
3608 or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
3609 or (
3610 os.path.isdir(checkpoint)
3611 and any(
3612 OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
3613 for folder_name in os.listdir(checkpoint)
3614 if os.path.isdir(os.path.join(checkpoint, folder_name))
3615 )
3616 )
3617 )
3618 )
3619 checkpoint_file_exists = (
3620 glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
3621 if self.is_fsdp_xla_v1_enabled
3622 else checkpoint_file_exists
3623 )
3624 if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
3625 # Load in optimizer and scheduler states
3626 if is_torch_xla_available():
3627 # On TPU we have to take some extra precautions to properly load the states on the right device.
3628 if self.is_fsdp_xla_v1_enabled:
3629 check_torch_load_is_safe()
3630 optimizer_state = torch.load(
3631 os.path.join(
3632 checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
3633 ),
3634 map_location="cpu",
3635 weights_only=True,
3636 )
3637 # We only need `optimizer` when resuming from checkpoint
3638 optimizer_state = optimizer_state["optimizer"]
3639 else:
3640 check_torch_load_is_safe()
3641 optimizer_state = torch.load(
3642 os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
3643 )
3644 with warnings.catch_warnings(record=True) as caught_warnings:

Callers 1

_prepare_for_training · Method · 0.95

Calls 8

check_torch_load_is_safe · Function · 0.85
reissue_pt_warnings · Function · 0.85
is_sagemaker_mp_enabled · Function · 0.85
is_torch_xla_available · Function · 0.85
get_fsdp_ckpt_kwargs · Function · 0.85
join · Method · 0.80
split · Method · 0.80
load_state_dict · Method · 0.45

Tested by

no test coverage detected