Load model weights from a checkpoint directory.
(self, resume_from_checkpoint: str, model: nn.Module | None = None)
| 3305 | # ---- Checkpoint Resuming ---- |
| 3306 | |
| 3307 | def _load_from_checkpoint(self, resume_from_checkpoint: str, model: nn.Module | None = None) -> None: |
| 3308 | """Load model weights from a checkpoint directory.""" |
| 3309 | if model is None: |
| 3310 | model = self.model |
| 3311 | |
| 3312 | config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) |
| 3313 | adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME) |
| 3314 | adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) |
| 3315 | weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) |
| 3316 | weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) |
| 3317 | safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) |
| 3318 | safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) |
| 3319 | is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and ( |
| 3320 | # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used |
| 3321 | any( |
| 3322 | FSDP_MODEL_NAME in folder_name |
| 3323 | for folder_name in os.listdir(resume_from_checkpoint) |
| 3324 | if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) |
| 3325 | ) |
| 3326 | # this checks the FSDP state dict when `FULL_STATE_DICT` is used |
| 3327 | or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin")) |
| 3328 | ) |
| 3329 | # if multiple adapters exist, they get saved in subdirectories |
| 3330 | adapter_subdirs = ( |
| 3331 | [ |
| 3332 | folder_name |
| 3333 | for folder_name in os.listdir(resume_from_checkpoint) |
| 3334 | if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) |
| 3335 | and ( |
| 3336 | os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME)) |
| 3337 | or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME)) |
| 3338 | ) |
| 3339 | ] |
| 3340 | if os.path.isdir(resume_from_checkpoint) |
| 3341 | else [] |
| 3342 | ) |
| 3343 | |
| 3344 | if is_fsdp_ckpt and not self.is_fsdp_enabled: |
| 3345 | raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP") |
| 3346 | |
| 3347 | if not ( |
| 3348 | any( |
| 3349 | os.path.isfile(f) |
| 3350 | for f in [ |
| 3351 | weights_file, |
| 3352 | safe_weights_file, |
| 3353 | weights_index_file, |
| 3354 | safe_weights_index_file, |
| 3355 | adapter_weights_file, |
| 3356 | adapter_safe_weights_file, |
| 3357 | ] |
| 3358 | ) |
| 3359 | or is_fsdp_ckpt |
| 3360 | or adapter_subdirs |
| 3361 | ): |
| 3362 | raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") |
| 3363 | |
| 3364 | logger.info(f"Loading model from {resume_from_checkpoint}.") |
no test coverage detected