Delete older checkpoints, keeping at most `save_total_limit`. Always preserves the most recent checkpoint and the best model checkpoint (if provided).
Args:
output_dir (`str`): The directory containing the checkpoints.
save_total_limit (`int`, *optional*): Maximum number of checkpoints to keep. No deletion if `None` or <= 0.
(
output_dir: str,
save_total_limit: int | None = None,
best_model_checkpoint: str | None = None,
use_mtime: bool = False,
checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
)
| 340 | |
| 341 | |
def rotate_checkpoints(
    output_dir: str,
    save_total_limit: int | None = None,
    best_model_checkpoint: str | None = None,
    use_mtime: bool = False,
    checkpoint_prefix: str = PREFIX_CHECKPOINT_DIR,
) -> None:
    """
    Delete older checkpoints, keeping at most `save_total_limit`.

    Always preserves the most recent checkpoint and the best model checkpoint (if provided).
    If fewer than that many slots are allowed, the preserved checkpoints that are actually
    present are still kept, so the result may exceed `save_total_limit` by at most one.
    Checkpoints are deleted oldest-first, in the order returned by `sort_checkpoints`.

    Args:
        output_dir (`str`):
            The directory containing the checkpoints.
        save_total_limit (`int`, *optional*):
            Maximum number of checkpoints to keep. No deletion if `None` or <= 0.
        best_model_checkpoint (`str`, *optional*):
            Path to best checkpoint (will always be preserved). A path that does not match
            any listed checkpoint does not count toward the number of checkpoints kept.
        use_mtime (`bool`, *optional*, defaults to `False`):
            Whether to sort by modification time instead of step number.
        checkpoint_prefix (`str`, *optional*, defaults to `"checkpoint"`):
            The prefix used for checkpoint directory names.
    """
    if save_total_limit is None or save_total_limit <= 0:
        return

    checkpoints = sort_checkpoints(output_dir, checkpoint_prefix, use_mtime)
    if len(checkpoints) <= save_total_limit:
        return

    # Compare paths in a canonical form on BOTH sides, so that e.g. "./out/checkpoint-5/"
    # and "out/checkpoint-5" (or a Path vs. a str) are recognised as the same checkpoint.
    normalized = [str(Path(checkpoint)) for checkpoint in checkpoints]

    # Checkpoints that must not be deleted
    protected = {normalized[-1]}  # most recent, for resuming
    if best_model_checkpoint is not None:
        protected.add(str(Path(best_model_checkpoint)))

    # Only protected checkpoints that are actually in the list raise the floor; a best-model
    # path that is not listed here must not make us keep more than `save_total_limit`.
    num_protected_present = len(protected.intersection(normalized))
    num_to_keep = max(save_total_limit, num_protected_present)

    # Delete oldest non-protected checkpoints until we have num_to_keep left
    remaining = len(checkpoints)
    for checkpoint, key in zip(checkpoints, normalized):
        if remaining <= num_to_keep:
            break
        if key not in protected:
            # Best-effort removal: ignore_errors=True means filesystem errors are swallowed.
            shutil.rmtree(checkpoint, ignore_errors=True)
            remaining -= 1
| 388 | |
| 389 | class IntervalStrategy(ExplicitEnum): |