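Saves the model so it can be reloaded with `from_pretrained()`. Only processes for which `args.should_save` is true (normally the main process) write the model files, but some backends (SageMaker model parallel, FSDP, DeepSpeed) build the state dict on every process first, so the method should be called on all processes.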
(self, output_dir: str | None = None, _internal_call: bool = False)
# ---- Saving & Serialization ----

def save_model(self, output_dir: str | None = None, _internal_call: bool = False) -> None:
    """
    Save the model so it can be reloaded with `from_pretrained()`.

    Model files are only written by processes for which `self.args.should_save` is true
    (normally the main process). The SageMaker model-parallel, FSDP and DeepSpeed branches
    build the state dict *before* that check, and SageMaker requires the `state_dict()` call
    on all processes, so this method should be called on every process.

    Args:
        output_dir (`str`, *optional*):
            Directory to save into. Defaults to `self.args.output_dir`.
        _internal_call (`bool`, *optional*, defaults to `False`):
            Set to `True` by internal callers to skip the automatic Hub push that
            `self.args.push_to_hub` would otherwise trigger after saving.
    """
    if output_dir is None:
        output_dir = self.args.output_dir

    if is_torch_xla_available():
        # TPU saving is fully delegated; NOTE(review): any main-process gating happens inside
        # `save_tpu_checkpoint` — confirm there.
        save_tpu_checkpoint(
            self.model, self.args, self.accelerator, self.processing_class, self.is_fsdp_xla_v1_enabled, output_dir
        )
    elif is_sagemaker_mp_enabled():
        # Calling the state_dict needs to be done on the wrapped model and on all processes.
        os.makedirs(output_dir, exist_ok=True)
        state_dict = self.model_wrapped.state_dict()
        if self.args.should_save:
            self._save(output_dir, state_dict=state_dict)
        # Touch an empty `user_content.pt` marker next to the weights.
        # NOTE(review): placed outside the `should_save` check, reconstructed from a listing whose
        # indentation was lost — confirm against the original source.
        Path(os.path.join(output_dir, "user_content.pt")).touch()
    elif self.is_fsdp_enabled:
        # Only a FULL_STATE_DICT is consolidated and written here; for any other state-dict type
        # this branch writes nothing. NOTE(review): sharded weights are presumably saved by another
        # code path (e.g. checkpointing) — confirm before relying on `save_model` with them.
        if "FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type):
            state_dict = self.accelerator.get_state_dict(self.model)
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
    elif self.is_deepspeed_enabled:
        # Not every `save_checkpoint` signature accepts `exclude_frozen_parameters`; probe for it
        # instead of assuming it.
        accept_exclude_frozen_parameters = (
            "exclude_frozen_parameters" in inspect.signature(self.model_wrapped.save_checkpoint).parameters
        )
        zero3_sharding = self.deepspeed.config.get("zero_optimization", {}).get("stage", None) == 3
        try:
            if accept_exclude_frozen_parameters and _is_peft_model(self.model) and zero3_sharding:
                # When using PEFT with DeepSpeed ZeRO Stage 3,
                # we do not need to load the frozen parameters
                state_dict = self.deepspeed._zero3_consolidated_16bit_state_dict(exclude_frozen_parameters=True)
            else:
                state_dict = self.accelerator.get_state_dict(self.deepspeed)
        except ValueError:
            # A ValueError while gathering is treated as "16-bit weights cannot be gathered"
            # (see the warning text): fall back to a full DeepSpeed checkpoint instead.
            logger.warning(
                " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                " zero_to_fp32.py to recover weights"
            )
            if self.args.should_save:
                # Save with an empty state dict; the weight files it produces are removed right below.
                self._save(output_dir, state_dict={})
            # remove the dummy state_dict
            remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
            self.model_wrapped.save_checkpoint(output_dir)
        else:
            # Writing is kept outside the `try` so that a ValueError raised while saving is not
            # misreported as the gather failure handled above (and does not trigger the fallback).
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
    elif self.args.should_save:
        self._save(output_dir)

    # Push to the Hub when `save_model` is called by the user.
    if self.args.push_to_hub and not _internal_call:
        self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
