If scaler state exists, load it.
(self, checkpoint: str | None)
| 3689 | reissue_pt_warnings(caught_warnings) |
| 3690 | |
| 3691 | def _load_scaler(self, checkpoint: str | None) -> None: |
| 3692 | """If scaler state exists, load it.""" |
| 3693 | if checkpoint is None: |
| 3694 | return |
| 3695 | |
| 3696 | checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME)) |
| 3697 | |
| 3698 | if checkpoint_file_exists: |
| 3699 | # On TPU we have to take some extra precautions to properly load the states on the right device. |
| 3700 | # Load in scaler states |
| 3701 | if is_torch_xla_available(): |
| 3702 | with warnings.catch_warnings(record=True) as caught_warnings: |
| 3703 | check_torch_load_is_safe() |
| 3704 | scaler_state = torch.load( |
| 3705 | os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True |
| 3706 | ) |
| 3707 | reissue_pt_warnings(caught_warnings) |
| 3708 | xm.send_cpu_data_to_device(scaler_state, self.args.device) |
| 3709 | self.accelerator.scaler.load_state_dict(scaler_state) |
| 3710 | else: |
| 3711 | with warnings.catch_warnings(record=True) as caught_warnings: |
| 3712 | check_torch_load_is_safe() |
| 3713 | self.accelerator.scaler.load_state_dict( |
| 3714 | torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True) |
| 3715 | ) |
| 3716 | reissue_pt_warnings(caught_warnings) |
| 3717 | |
| 3718 | def _load_callback_state(self) -> None: |
| 3719 | """If callback states exist and were passed in, restore their states if enabled""" |
no test coverage detected