Save random number generator states for reproducible resumption.
(self, output_dir: str)
| 3165 | return is_new_best_metric |
| 3166 | |
| 3167 | def _save_rng_state(self, output_dir: str) -> None: |
| 3168 | """Save random number generator states for reproducible resumption.""" |
| 3169 | # Collect the Python, NumPy and torch CPU RNG states; these are saved in every parallel mode |
| 3170 | rng_states = { |
| 3171 | "python": random.getstate(), |
| 3172 | "numpy": np.random.get_state(), |
| 3173 | "cpu": torch.random.get_rng_state(), |
| 3174 | } |
| 3175 | if torch.cuda.is_available(): |
| 3176 | if self.args.parallel_mode == ParallelMode.DISTRIBUTED: |
| 3177 | # Distributed: save the RNG state of every visible CUDA device; otherwise (else branch) only the current device's state is saved |
| 3178 | rng_states["cuda"] = torch.cuda.random.get_rng_state_all() |
| 3179 | else: |
| 3180 | rng_states["cuda"] = torch.cuda.random.get_rng_state() |
| 3181 | |
| 3182 | if is_torch_xla_available(): |
| 3183 | rng_states["xla"] = xm.get_rng_state() |
| 3184 | |
| 3185 | if is_torch_npu_available(): |
| 3186 | if self.args.parallel_mode == ParallelMode.DISTRIBUTED: |
| 3187 | rng_states["npu"] = torch.npu.random.get_rng_state_all() |
| 3188 | else: |
| 3189 | rng_states["npu"] = torch.npu.random.get_rng_state() |
| 3190 | |
| 3191 | if is_torch_hpu_available(): |
| 3192 | if self.args.parallel_mode == ParallelMode.DISTRIBUTED: |
| 3193 | rng_states["hpu"] = torch.hpu.random.get_rng_state_all() |
| 3194 | else: |
| 3195 | rng_states["hpu"] = torch.hpu.random.get_rng_state() |
| 3196 | |
| 3197 | if is_torch_mlu_available(): |
| 3198 | if self.args.parallel_mode == ParallelMode.DISTRIBUTED: |
| 3199 | rng_states["mlu"] = torch.mlu.random.get_rng_state_all() |
| 3200 | else: |
| 3201 | rng_states["mlu"] = torch.mlu.random.get_rng_state() |
| 3202 | |
| 3203 | if is_torch_musa_available(): |
| 3204 | if self.args.parallel_mode == ParallelMode.DISTRIBUTED: |
| 3205 | rng_states["musa"] = torch.musa.get_rng_state_all() |
| 3206 | else: |
| 3207 | rng_states["musa"] = torch.musa.get_rng_state() |
| 3208 | |
| 3209 | # A process can arrive here before process 0 has had a chance to save the model, in which case output_dir |
| 3210 | # may not exist yet. |
| 3211 | os.makedirs(output_dir, exist_ok=True) |
| 3212 | |
| 3213 | if self.args.world_size <= 1: |
| 3214 | torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) |
| 3215 | else: |
| 3216 | torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) |
| 3217 | |
| 3218 | def _save_optimizer_and_scheduler(self, output_dir: str) -> None: |
| 3219 | """Save optimizer and learning rate scheduler states to `output_dir`.""" |
no test coverage detected