Save model checkpoint, optimizer, scheduler, scaler, RNG states, and trainer state.
(self, model: nn.Module, trial: "optuna.Trial | dict[str, Any] | None")
| 3058 | return run_dir |
| 3059 | |
def _save_checkpoint(self, model: nn.Module, trial: "optuna.Trial | dict[str, Any] | None") -> None:
    """Save model checkpoint, optimizer, scheduler, scaler, RNG states, and trainer state.

    Everything is written to ``os.path.join(run_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")``,
    where ``run_dir`` is ``self._get_output_dir(trial=trial)``.

    Args:
        model: Not referenced by this method; the weights are written by ``self.save_model``.
            Kept so existing callers keep working.
        trial: Hyperparameter-search trial (or its parameter dict), forwarded to
            ``_get_output_dir``. When there is no HP-search backend and ``trial`` is ``None``,
            ``self.store_flos()`` is called before saving.
    """
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
    # want to save except FullyShardedDDP.
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir, _internal_call=True)

    if (
        self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH, SaveStrategy.BEST]
        and self.state.best_global_step
    ):
        # Wait for everyone to get here so we are sure the model has been saved by process 0
        # before we check if the best_checkpoint_dir exists
        if is_torch_xla_available():
            xm.rendezvous("load_best_model_at_end")
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()
        elif is_sagemaker_mp_enabled():
            smp.barrier()

        best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
        best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)

        # Only record the best checkpoint once its directory is actually on disk.
        if os.path.exists(best_checkpoint_dir):
            self.state.best_model_checkpoint = best_checkpoint_dir

    if not self.args.save_only_model:
        # Save optimizer and scheduler
        self._save_optimizer_and_scheduler(output_dir)
        self._save_scaler(output_dir)
        # Save RNG state
        self._save_rng_state(output_dir)

    # Save the Trainer state
    if self.args.should_save:
        # Snapshot every `ExportableState` callback (plus `TrainerControl`) as it is right now.
        # Several callbacks sharing one class name are stored together as a list, in callback
        # order. The mapping is rebuilt from scratch on every save: appending to the list that
        # is already in `stateful_callbacks` (seeded at construction or persisted by an earlier
        # save) would grow it on each checkpoint and keep stale states at its head.
        current_states: dict[str, Any] = {}
        for cb in self.callback_handler.callbacks + [self.control]:
            if not isinstance(cb, ExportableState):
                continue
            cb_name = cb.__class__.__name__
            cb_state = cb.state()
            if cb_name not in current_states:
                current_states[cb_name] = cb_state
            elif isinstance(current_states[cb_name], list):
                current_states[cb_name].append(cb_state)
            else:
                current_states[cb_name] = [current_states[cb_name], cb_state]
        # `update` keeps entries for callback classes not attached to this run and, unlike
        # indexing, does not fail on names missing from the stored dict.
        self.state.stateful_callbacks.update(current_states)
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)
no test coverage detected