MCPcopy
hub / github.com/huggingface/transformers / _load_from_checkpoint

Method _load_from_checkpoint

src/transformers/trainer.py:3307–3434  ·  view source on GitHub ↗

Load model weights from a checkpoint directory.

(self, resume_from_checkpoint: str, model: nn.Module | None = None) -> None

Source retrieved from the content-addressed store (hash-verified).

3305 # ---- Checkpoint Resuming ----
3306
3307 def _load_from_checkpoint(self, resume_from_checkpoint: str, model: nn.Module | None = None) -> None:
3308 """Load model weights from a checkpoint directory."""
3309 if model is None:
3310 model = self.model
3311
3312 config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
3313 adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
3314 adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
3315 weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
3316 weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
3317 safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
3318 safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
3319 is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
3320 # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
3321 any(
3322 FSDP_MODEL_NAME in folder_name
3323 for folder_name in os.listdir(resume_from_checkpoint)
3324 if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
3325 )
3326 # this checks the FSDP state dict when `FULL_STATE_DICT` is used
3327 or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
3328 )
3329 # if multiple adapters exist, they get saved in sub directories
3330 adapter_subdirs = (
3331 [
3332 folder_name
3333 for folder_name in os.listdir(resume_from_checkpoint)
3334 if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
3335 and (
3336 os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
3337 or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
3338 )
3339 ]
3340 if os.path.isdir(resume_from_checkpoint)
3341 else []
3342 )
3343
3344 if is_fsdp_ckpt and not self.is_fsdp_enabled:
3345 raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
3346
3347 if not (
3348 any(
3349 os.path.isfile(f)
3350 for f in [
3351 weights_file,
3352 safe_weights_file,
3353 weights_index_file,
3354 safe_weights_index_file,
3355 adapter_weights_file,
3356 adapter_safe_weights_file,
3357 ]
3358 )
3359 or is_fsdp_ckpt
3360 or adapter_subdirs
3361 ):
3362 raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
3363
3364 logger.info(f"Loading model from {resume_from_checkpoint}.")

Callers 2

train (Method) · 0.95
_prepare_for_training (Method) · 0.95

Calls 13

is_sagemaker_mp_enabled (Function) · 0.85
get_fsdp_ckpt_kwargs (Function) · 0.85
check_torch_load_is_safe (Function) · 0.85
_is_peft_model (Function) · 0.85
load_sharded_checkpoint (Function) · 0.85
join (Method) · 0.80
warning (Method) · 0.80
set_adapter (Method) · 0.80
info (Method) · 0.45
from_json_file (Method) · 0.45
load_state_dict (Method) · 0.45

Tested by

No test coverage detected.