MCPcopy
hub / github.com/huggingface/transformers / save_model

Method save_model

src/transformers/trainer.py:3775–3831  ·  view source on GitHub ↗

Will save the model, so you can reload it using `from_pretrained()`. Will only save from the main process.

(self, output_dir: str | None = None, _internal_call: bool = False)

Source from the content-addressed store, hash-verified

3773 # ---- Saving & Serialization ----
3774
def save_model(self, output_dir: str | None = None, _internal_call: bool = False) -> None:
    """
    Will save the model, so you can reload it using `from_pretrained()`.

    Will only save from the main process.

    Args:
        output_dir (`str`, *optional*):
            Directory to save into. Falls back to `self.args.output_dir` when `None`.
        _internal_call (`bool`, *optional*, defaults to `False`):
            When `True`, the push to the Hub at the end of this method is skipped even if
            `self.args.push_to_hub` is set.
    """

    if output_dir is None:
        output_dir = self.args.output_dir

    if is_torch_xla_available():
        save_tpu_checkpoint(
            self.model, self.args, self.accelerator, self.processing_class, self.is_fsdp_xla_v1_enabled, output_dir
        )
    elif is_sagemaker_mp_enabled():
        # Calling the state_dict needs to be done on the wrapped model and on all processes.
        os.makedirs(output_dir, exist_ok=True)
        state_dict = self.model_wrapped.state_dict()
        if self.args.should_save:
            self._save(output_dir, state_dict=state_dict)
        Path(os.path.join(output_dir, "user_content.pt")).touch()
    elif self.is_fsdp_enabled:
        # Only the FULL_STATE_DICT configuration is written here; any other FSDP
        # state-dict type makes this branch a no-op.
        if "FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type):
            # The state dict is fetched before the `should_save` check, i.e. on every process.
            state_dict = self.accelerator.get_state_dict(self.model)
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
    elif self.is_deepspeed_enabled:
        try:
            # `Signature.parameters` is a mapping keyed by parameter name, so membership
            # can be tested on it directly.
            save_checkpoint_params = inspect.signature(self.model_wrapped.save_checkpoint).parameters
            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in save_checkpoint_params
            zero3_sharding = self.deepspeed.config.get("zero_optimization", {}).get("stage") == 3
            if accept_exclude_frozen_parameters and _is_peft_model(self.model) and zero3_sharding:
                # When using PEFT with DeepSpeed ZeRO Stage 3,
                # we do not need to load the frozen parameters
                state_dict = self.deepspeed._zero3_consolidated_16bit_state_dict(exclude_frozen_parameters=True)
            else:
                state_dict = self.accelerator.get_state_dict(self.deepspeed)
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        except ValueError:
            # Fallback path: write the full DeepSpeed checkpoint instead of a consolidated state dict.
            logger.warning(
                " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                " zero_to_fp32.py to recover weights"
            )
            if self.args.should_save:
                self._save(output_dir, state_dict={})
            # remove the dummy state_dict
            remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
            self.model_wrapped.save_checkpoint(output_dir)

    elif self.args.should_save:
        self._save(output_dir)

    # Push to the Hub when `save_model` is called by the user.
    if self.args.push_to_hub and not _internal_call:
        self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
3832

Callers 15

_save_checkpointMethod · 0.95
push_to_hubMethod · 0.95
_tune_save_checkpointMethod · 0.95
on_train_endMethod · 0.95
on_train_endMethod · 0.95
mainFunction · 0.95
mainFunction · 0.95
mainFunction · 0.95

Calls 12

_saveMethod · 0.95
push_to_hubMethod · 0.95
is_torch_xla_availableFunction · 0.85
save_tpu_checkpointFunction · 0.85
is_sagemaker_mp_enabledFunction · 0.85
_is_peft_modelFunction · 0.85
remove_dummy_checkpointFunction · 0.85
joinMethod · 0.80
warningMethod · 0.80
state_dictMethod · 0.45
keysMethod · 0.45
getMethod · 0.45