Restore random number generator states from a checkpoint.
(self, checkpoint: str | None)
| 3540 | ) |
| 3541 | |
| 3542 | def _load_rng_state(self, checkpoint: str | None) -> None: |
| 3543 | """Restore random number generator states from a checkpoint.""" |
| 3544 | # Load RNG states from `checkpoint` |
| 3545 | if checkpoint is None: |
| 3546 | return |
| 3547 | |
| 3548 | if self.args.world_size > 1: |
| 3549 | process_index = self.args.process_index |
| 3550 | rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") |
| 3551 | if not os.path.isfile(rng_file): |
| 3552 | logger.info( |
| 3553 | f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " |
| 3554 | "wasn't launched in a distributed fashion, reproducibility is not guaranteed." |
| 3555 | ) |
| 3556 | return |
| 3557 | else: |
| 3558 | rng_file = os.path.join(checkpoint, "rng_state.pth") |
| 3559 | if not os.path.isfile(rng_file): |
| 3560 | logger.info( |
| 3561 | "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " |
| 3562 | "fashion, reproducibility is not guaranteed." |
| 3563 | ) |
| 3564 | return |
| 3565 | |
| 3566 | with safe_globals(): |
| 3567 | check_torch_load_is_safe() |
| 3568 | checkpoint_rng_state = torch.load(rng_file, weights_only=True) |
| 3569 | random.setstate(checkpoint_rng_state["python"]) |
| 3570 | np.random.set_state(checkpoint_rng_state["numpy"]) |
| 3571 | torch.random.set_rng_state(checkpoint_rng_state["cpu"]) |
| 3572 | if is_torch_xla_available(): |
| 3573 | xm.set_rng_state(checkpoint_rng_state["xla"]) |
| 3574 | |
| 3575 | is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED |
| 3576 | if torch.cuda.is_available(): |
| 3577 | set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed) |
| 3578 | if is_torch_npu_available(): |
| 3579 | set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed) |
| 3580 | if is_torch_hpu_available(): |
| 3581 | set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed) |
| 3582 | if is_torch_mlu_available(): |
| 3583 | set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed) |
| 3584 | if is_torch_musa_available(): |
| 3585 | set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed) |
| 3586 | |
| 3587 | def _load_optimizer_and_scheduler(self, checkpoint: str | None) -> None: |
| 3588 | """If optimizer and scheduler states exist, load them.""" |
no test coverage detected