MCPcopy
hub / github.com/huggingface/transformers / _load_best_model

Method _load_best_model

src/transformers/trainer.py:3436–3540  ·  view source on GitHub ↗

Load the best model found during training based on the tracked metric.

(self)

Source from the content-addressed store, hash-verified

3434 self._issue_warnings_after_load(load_result)
3435
3436 def _load_best_model(self) -> None:
3437 """Load the best model found during training based on the tracked metric."""
3438 logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
3439 best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
3440 best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
3441 best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
3442 best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
3443
3444 model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
3445 if self.is_deepspeed_enabled:
3446 deepspeed_load_checkpoint(
3447 self.model_wrapped,
3448 self.state.best_model_checkpoint,
3449 load_module_strict=not _is_peft_model(self.model),
3450 )
3451 elif self.is_fsdp_enabled:
3452 load_result = load_fsdp_model(
3453 self.accelerator.state.fsdp_plugin,
3454 self.accelerator,
3455 model,
3456 self.state.best_model_checkpoint,
3457 **get_fsdp_ckpt_kwargs(),
3458 )
3459 elif (
3460 os.path.exists(best_model_path)
3461 or os.path.exists(best_safe_model_path)
3462 or os.path.exists(best_adapter_model_path)
3463 or os.path.exists(best_safe_adapter_model_path)
3464 ):
3465 has_been_loaded = True
3466 if is_sagemaker_mp_enabled():
3467 smp.resume_from_checkpoint(
3468 path=self.state.best_model_checkpoint,
3469 tag=WEIGHTS_NAME,
3470 partial=False,
3471 load_optimizer=False,
3472 )
3473 else:
3474 if _is_peft_model(model):
3475 # If training a model using PEFT, assume that adapter have been saved properly.
3476 if hasattr(model, "active_adapters") and hasattr(model, "load_adapter"):
3477 active_adapter = model.active_adapters[0]
3478 if len(model.active_adapters) > 1:
3479 logger.warning("Detected multiple active adapters, will only consider the first one")
3480
3481 if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
3482 try:
3483 model.load_adapter(self.state.best_model_checkpoint, active_adapter)
3484 except RuntimeError as exc:
3485 if model.peft_config[active_adapter].is_prompt_learning:
3486 # for context: https://github.com/huggingface/peft/issues/2256
3487 msg = (
3488 "When using prompt learning PEFT methods such as "
3489 f"{model.peft_config[active_adapter].peft_type.value}, setting "
3490 "load_best_model_at_end=True can lead to errors, it is recommended "
3491 "to set this to False and to load the model manually from the checkpoint "
3492 "directory using PeftModel.from_pretrained(base_model, <path>) after training "
3493 "has finished."

Callers 1

_finalize_training (Method) · 0.95

Calls 12

is_sagemaker_mp_enabled (Function) · 0.85
_is_peft_model (Function) · 0.85
get_fsdp_ckpt_kwargs (Function) · 0.85
check_torch_load_is_safe (Function) · 0.85
load_sharded_checkpoint (Function) · 0.85
join (Method) · 0.80
warning (Method) · 0.80
info (Method) · 0.45
load_adapter (Method) · 0.45
load_state_dict (Method) · 0.45

Tested by

no test coverage detected