Load the best model found during training based on the tracked metric.
(self)
| 3434 | self._issue_warnings_after_load(load_result) |
| 3435 | |
| 3436 | def _load_best_model(self) -> None: |
| 3437 | """Load the best model found during training based on the tracked metric.""" |
| 3438 | logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") |
| 3439 | best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) |
| 3440 | best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) |
| 3441 | best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) |
| 3442 | best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) |
| 3443 | |
| 3444 | model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model |
| 3445 | if self.is_deepspeed_enabled: |
| 3446 | deepspeed_load_checkpoint( |
| 3447 | self.model_wrapped, |
| 3448 | self.state.best_model_checkpoint, |
| 3449 | load_module_strict=not _is_peft_model(self.model), |
| 3450 | ) |
| 3451 | elif self.is_fsdp_enabled: |
| 3452 | load_result = load_fsdp_model( |
| 3453 | self.accelerator.state.fsdp_plugin, |
| 3454 | self.accelerator, |
| 3455 | model, |
| 3456 | self.state.best_model_checkpoint, |
| 3457 | **get_fsdp_ckpt_kwargs(), |
| 3458 | ) |
| 3459 | elif ( |
| 3460 | os.path.exists(best_model_path) |
| 3461 | or os.path.exists(best_safe_model_path) |
| 3462 | or os.path.exists(best_adapter_model_path) |
| 3463 | or os.path.exists(best_safe_adapter_model_path) |
| 3464 | ): |
| 3465 | has_been_loaded = True |
| 3466 | if is_sagemaker_mp_enabled(): |
| 3467 | smp.resume_from_checkpoint( |
| 3468 | path=self.state.best_model_checkpoint, |
| 3469 | tag=WEIGHTS_NAME, |
| 3470 | partial=False, |
| 3471 | load_optimizer=False, |
| 3472 | ) |
| 3473 | else: |
| 3474 | if _is_peft_model(model): |
| 3475 | # If training a model using PEFT, assume that adapter have been saved properly. |
| 3476 | if hasattr(model, "active_adapters") and hasattr(model, "load_adapter"): |
| 3477 | active_adapter = model.active_adapters[0] |
| 3478 | if len(model.active_adapters) > 1: |
| 3479 | logger.warning("Detected multiple active adapters, will only consider the first one") |
| 3480 | |
| 3481 | if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): |
| 3482 | try: |
| 3483 | model.load_adapter(self.state.best_model_checkpoint, active_adapter) |
| 3484 | except RuntimeError as exc: |
| 3485 | if model.peft_config[active_adapter].is_prompt_learning: |
| 3486 | # for context: https://github.com/huggingface/peft/issues/2256 |
| 3487 | msg = ( |
| 3488 | "When using prompt learning PEFT methods such as " |
| 3489 | f"{model.peft_config[active_adapter].peft_type.value}, setting " |
| 3490 | "load_best_model_at_end=True can lead to errors, it is recommended " |
| 3491 | "to set this to False and to load the model manually from the checkpoint " |
| 3492 | "directory using PeftModel.from_pretrained(base_model, <path>) after training " |
| 3493 | "has finished." |
no test coverage detected