MCPcopy
hub / github.com/huggingface/transformers / _save_rng_state

Method _save_rng_state

src/transformers/trainer.py:3167–3216  ·  view source on GitHub ↗

Save random number generator states for reproducible resumption.

(self, output_dir: str)

Source from the content-addressed store, hash-verified

3165 return is_new_best_metric
3166
def _save_rng_state(self, output_dir: str) -> None:
    """Snapshot every active random number generator into `output_dir`.

    The snapshot is a dict keyed by generator name ("python", "numpy", "cpu",
    plus one entry per available accelerator backend) and is written with
    `torch.save` so that a resumed run can restore the same random streams.

    Args:
        output_dir: Checkpoint directory to write into. It is created if it
            does not exist yet.
    """

    def _capture(rng_module):
        # In distributed mode the `_all` variant is used; otherwise the plain
        # single-state getter is called.
        # NOTE(review): an earlier comment associated the `_all` variant with
        # non-distributed (DataParallel) runs, which is the opposite of this
        # condition -- confirm against the matching loader before changing it.
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            return rng_module.get_rng_state_all()
        return rng_module.get_rng_state()

    # Host-side generators are always recorded.
    rng_states = {}
    rng_states["python"] = random.getstate()
    rng_states["numpy"] = np.random.get_state()
    rng_states["cpu"] = torch.random.get_rng_state()

    # Accelerator generators are recorded only for backends that are available.
    if torch.cuda.is_available():
        rng_states["cuda"] = _capture(torch.cuda.random)

    # XLA exposes a single getter, with no per-mode distinction.
    if is_torch_xla_available():
        rng_states["xla"] = xm.get_rng_state()

    if is_torch_npu_available():
        rng_states["npu"] = _capture(torch.npu.random)

    if is_torch_hpu_available():
        rng_states["hpu"] = _capture(torch.hpu.random)

    if is_torch_mlu_available():
        rng_states["mlu"] = _capture(torch.mlu.random)

    # MUSA's getters are reached on `torch.musa` directly rather than through
    # a `.random` submodule.
    if is_torch_musa_available():
        rng_states["musa"] = _capture(torch.musa)

    # A process can get here before process 0 has saved the model, so the
    # checkpoint directory may not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    # One shared file for a single process; one file per process index otherwise.
    single_process = self.args.world_size <= 1
    file_name = "rng_state.pth" if single_process else f"rng_state_{self.args.process_index}.pth"
    torch.save(rng_states, os.path.join(output_dir, file_name))
3217
3218 def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
3219 """Save optimizer and learning rate scheduler states to `output_dir`."""

Callers 1

_save_checkpointMethod · 0.95

Calls 8

is_torch_xla_availableFunction · 0.85
is_torch_npu_availableFunction · 0.85
is_torch_hpu_availableFunction · 0.85
is_torch_mlu_availableFunction · 0.85
is_torch_musa_availableFunction · 0.85
joinMethod · 0.80
is_availableMethod · 0.45
saveMethod · 0.45

Tested by

no test coverage detected