github.com/huggingface/transformers

Method _load_scaler

src/transformers/trainer.py:3691–3716

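If scaler state exists, load it. When checkpoint is not None and the checkpoint directory contains a file named SCALER_NAME, that file is read with torch.load(..., weights_only=True) and passed to self.accelerator.scaler.load_state_dict. Otherwise the method returns without doing anything. When torch_xla is available (TPU), the state is loaded onto the CPU first.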

(self, checkpoint: str | None) -> None

```python
def _load_scaler(self, checkpoint: str | None) -> None:
    """If scaler state exists, load it."""
    if checkpoint is None:
        return

    checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME))

    if checkpoint_file_exists:
        # On TPU we have to take some extra precautions to properly load the states on the right device.
        # Load in scaler states
        if is_torch_xla_available():
            with warnings.catch_warnings(record=True) as caught_warnings:
                check_torch_load_is_safe()
                scaler_state = torch.load(
                    os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
                )
            reissue_pt_warnings(caught_warnings)
            xm.send_cpu_data_to_device(scaler_state, self.args.device)
            self.accelerator.scaler.load_state_dict(scaler_state)
        else:
            with warnings.catch_warnings(record=True) as caught_warnings:
                check_torch_load_is_safe()
                self.accelerator.scaler.load_state_dict(
                    torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
                )
            reissue_pt_warnings(caught_warnings)
```
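
The guard logic of _load_scaler (skip when checkpoint is None, skip when the scaler file is missing, otherwise load and hand the state to load_state_dict) can be exercised without PyTorch. The sketch below is a minimal, hypothetical re-creation of that control flow only: `ToyScaler`, `SCALER_FILE`, and `load_scaler` are illustrative names, JSON stands in for `torch.load(..., weights_only=True)`, and the function returns a boolean (the real method returns None) so that the two no-op paths are observable.

```python
import json
import os
import tempfile

# Hypothetical file name; the real method uses the SCALER_NAME constant and torch.load.
SCALER_FILE = "scaler.json"


class ToyScaler:
    """Hypothetical stand-in for a gradient scaler that exposes a state dict."""

    def __init__(self):
        self.state = {"scale": 65536.0, "growth_tracker": 0}

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state):
        self.state = dict(state)


def load_scaler(scaler, checkpoint):
    """Restore scaler state from `checkpoint` if a saved scaler file exists.

    Returns True when state was loaded and False on the two no-op paths
    (the original method returns None in every case).
    """
    if checkpoint is None:  # nothing to resume from
        return False
    path = os.path.join(checkpoint, SCALER_FILE)
    if not os.path.isfile(path):  # checkpoint was saved without scaler state
        return False
    with open(path) as f:  # json.load stands in for torch.load(path, weights_only=True)
        scaler.load_state_dict(json.load(f))
    return True


if __name__ == "__main__":
    scaler = ToyScaler()
    with tempfile.TemporaryDirectory() as ckpt:
        print(load_scaler(scaler, ckpt))  # directory exists but holds no scaler file
        with open(os.path.join(ckpt, SCALER_FILE), "w") as f:
            json.dump({"scale": 1024.0, "growth_tracker": 3}, f)
        print(load_scaler(scaler, ckpt), scaler.state_dict())
```

Treating a missing file as a silent no-op presumably lets training resume from checkpoints that never saved scaler state, rather than failing on them.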

Callers (1)

_prepare_for_training · Method · 0.95

Calls (5)

is_torch_xla_available · Function · 0.85
check_torch_load_is_safe · Function · 0.85
reissue_pt_warnings · Function · 0.85
join · Method · 0.80
load_state_dict · Method · 0.45
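
Both branches of _load_scaler wrap `torch.load` in `warnings.catch_warnings(record=True)` and pass the captured list to `reissue_pt_warnings` once the block exits. What that helper filters is not shown on this page, so the sketch below assumes a generic capture, filter, and re-emit pattern; `load_with_reissued_warnings`, its `suppress` predicate, and `noisy_load` are hypothetical names, not transformers APIs.

```python
import warnings


def load_with_reissued_warnings(load_fn, suppress=lambda w: False):
    """Run load_fn while recording its warnings, then re-emit the ones not suppressed.

    Hypothetical helper illustrating the capture-and-reissue pattern; it is not
    the transformers implementation of reissue_pt_warnings.
    """
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning raised inside the block
        result = load_fn()
    # Filters and showwarning are restored on exit, so re-emitted warnings
    # follow the caller's warning policy.
    for w in caught:
        if not suppress(w):
            warnings.warn(w.message, w.category)
    return result


def noisy_load():
    """Toy loader that emits one warning worth keeping and one to drop."""
    warnings.warn("keep me", RuntimeWarning)
    warnings.warn("drop me", UserWarning)
    return {"scale": 1024.0}
```

Re-emitting after the `with` block, rather than inside it, is what lets the surviving warnings reach whatever handler the caller has installed.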

Tested by

no test coverage detected