MCPcopy
hub / github.com/huggingface/transformers / _load_rng_state

Method _load_rng_state

src/transformers/trainer.py:3542–3585  ·  view source on GitHub ↗

Restore random number generator states from a checkpoint.

(self, checkpoint: str | None)

Source from the content-addressed store, hash-verified

3540 )
3541
3542 def _load_rng_state(self, checkpoint: str | None) -> None:
3543 """Restore random number generator states from a checkpoint."""
3544 # Load RNG states from `checkpoint`
3545 if checkpoint is None:
3546 return
3547
3548 if self.args.world_size > 1:
3549 process_index = self.args.process_index
3550 rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
3551 if not os.path.isfile(rng_file):
3552 logger.info(
3553 f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
3554 "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
3555 )
3556 return
3557 else:
3558 rng_file = os.path.join(checkpoint, "rng_state.pth")
3559 if not os.path.isfile(rng_file):
3560 logger.info(
3561 "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
3562 "fashion, reproducibility is not guaranteed."
3563 )
3564 return
3565
3566 with safe_globals():
3567 check_torch_load_is_safe()
3568 checkpoint_rng_state = torch.load(rng_file, weights_only=True)
3569 random.setstate(checkpoint_rng_state["python"])
3570 np.random.set_state(checkpoint_rng_state["numpy"])
3571 torch.random.set_rng_state(checkpoint_rng_state["cpu"])
3572 if is_torch_xla_available():
3573 xm.set_rng_state(checkpoint_rng_state["xla"])
3574
3575 is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
3576 if torch.cuda.is_available():
3577 set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed)
3578 if is_torch_npu_available():
3579 set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
3580 if is_torch_hpu_available():
3581 set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed)
3582 if is_torch_mlu_available():
3583 set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed)
3584 if is_torch_musa_available():
3585 set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)
3586
3587 def _load_optimizer_and_scheduler(self, checkpoint: str | None) -> None:
3588 """If optimizer and scheduler states exist, load them."""

Callers 1

_run_epoch · Method · 0.95

Calls 11

safe_globals · Function · 0.85
check_torch_load_is_safe · Function · 0.85
is_torch_xla_available · Function · 0.85
set_rng_state_for_device · Function · 0.85
is_torch_npu_available · Function · 0.85
is_torch_hpu_available · Function · 0.85
is_torch_mlu_available · Function · 0.85
is_torch_musa_available · Function · 0.85
join · Method · 0.80
info · Method · 0.45
is_available · Method · 0.45

Tested by

no test coverage detected