MCPcopy
hub / github.com/huggingface/transformers / _save_checkpoint

Method _save_checkpoint

src/transformers/trainer.py:3060–3127  ·  view source on GitHub ↗

Save model checkpoint, optimizer, scheduler, scaler, RNG states, and trainer state.

(self, model: nn.Module, trial: "optuna.Trial | dict[str, Any] | None")

Source from the content-addressed store, hash-verified

3058 return run_dir
3059
def _save_checkpoint(self, model: nn.Module, trial: "optuna.Trial | dict[str, Any] | None") -> None:
    """Write a checkpoint for the current global step.

    The checkpoint directory is ``<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``,
    where ``run_dir`` comes from ``self._get_output_dir(trial=trial)``. The steps,
    in order, are:

    1. Store FLOs (only when no hyper-parameter search backend is set and
       ``trial`` is ``None``).
    2. Save the model via ``self.save_model``.
    3. If a best global step is recorded and the save strategy is STEPS, EPOCH
       or BEST, synchronise processes and point
       ``self.state.best_model_checkpoint`` at the best checkpoint directory
       when it exists on disk.
    4. Unless ``args.save_only_model`` is set, save optimizer/scheduler,
       scaler and RNG state.
    5. When ``args.should_save`` is set, refresh the exportable callback
       states and write the trainer state JSON.
    6. When ``args.push_to_hub`` is set, push from the checkpoint directory.

    Args:
        model: Not referenced by the statements shown in this listing.
        trial: Forwarded to ``self._get_output_dir``; also gates the
            ``store_flos`` call.
    """
    # Per the original note: in all cases (ddp/dp/deepspeed included) self.model is the
    # model we want to save, FullyShardedDDP being the exception.

    # Directory name for this step; resolved before anything else runs.
    step_dirname = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, step_dirname)
    self.save_model(output_dir, _internal_call=True)

    tracked_strategies = (SaveStrategy.STEPS, SaveStrategy.EPOCH, SaveStrategy.BEST)
    if self.args.save_strategy in tracked_strategies and self.state.best_global_step:
        # Every process must reach this point so that process 0 has finished saving
        # the model before the best checkpoint directory is looked up.
        if is_torch_xla_available():
            xm.rendezvous("load_best_model_at_end")
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()
        elif is_sagemaker_mp_enabled():
            smp.barrier()

        best_checkpoint_dir = os.path.join(
            run_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
        )
        if os.path.exists(best_checkpoint_dir):
            self.state.best_model_checkpoint = best_checkpoint_dir

    if not self.args.save_only_model:
        # Optimizer and scheduler first, then the scaler, then RNG state.
        self._save_optimizer_and_scheduler(output_dir)
        self._save_scaler(output_dir)
        self._save_rng_state(output_dir)

    if self.args.should_save:
        # Bring the `ExportableState` callbacks and the `TrainerControl` entry up to
        # date before the trainer state is serialised.
        for handler in self.callback_handler.callbacks + [self.control]:
            if not isinstance(handler, ExportableState):
                continue
            handler_name = handler.__class__.__name__
            snapshot = handler.state()
            registry = self.state.stateful_callbacks
            current_entry = registry[handler_name]
            if isinstance(current_entry, list):
                # A list entry collects states; add this one to it.
                current_entry.append(snapshot)
            else:
                registry[handler_name] = snapshot
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    # NOTE(review): the listing this body was taken from stops here, although the page
    # header gives a longer range for the method; later statements, if any, are not shown.

Callers 2

Calls 13

store_flos · Method · 0.95
_get_output_dir · Method · 0.95
save_model · Method · 0.95
_save_scaler · Method · 0.95
_save_rng_state · Method · 0.95
_push_from_checkpoint · Method · 0.95
is_torch_xla_available · Function · 0.85
is_sagemaker_mp_enabled · Function · 0.85
rotate_checkpoints · Function · 0.85
join · Method · 0.80
save_to_json · Method · 0.80

Tested by

no test coverage detected