Save the gradient scaler state if one exists.
(self, output_dir: str)
| 3282 | reissue_pt_warnings(caught_warnings) |
| 3283 | |
| 3284 | def _save_scaler(self, output_dir: str) -> None: |
| 3285 | """Save the gradient scaler state if one exists.""" |
| 3286 | # See if there is a scaler attribute |
| 3287 | try: |
| 3288 | scaler = self.accelerator.scaler |
| 3289 | except AttributeError: |
| 3290 | return |
| 3291 | if scaler is None: |
| 3292 | return |
| 3293 | if is_torch_xla_available(): |
| 3294 | xm.rendezvous("saving_scaler_state") |
| 3295 | with warnings.catch_warnings(record=True) as caught_warnings: |
| 3296 | xm.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) |
| 3297 | reissue_pt_warnings(caught_warnings) |
| 3298 | |
| 3299 | # Save SCALER |
| 3300 | if self.args.should_save and not is_torch_xla_available(): |
| 3301 | with warnings.catch_warnings(record=True) as caught_warnings: |
| 3302 | torch.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) |
| 3303 | reissue_pt_warnings(caught_warnings) |
| 3304 | |
| 3305 | # ---- Checkpoint Resuming ---- |
| 3306 |
no test coverage detected