Save a checkpoint during a Ray Tune hyperparameter search trial.
Signature: `_tune_save_checkpoint(self, checkpoint_dir: str) -> None`
| 4358 | ray.tune.report(metrics, checkpoint=checkpoint) |
| 4359 | |
def _tune_save_checkpoint(self, checkpoint_dir: str) -> None:
    """Save a checkpoint during a Ray Tune hyperparameter search trial.

    Everything is written into a step-named subdirectory of ``checkpoint_dir``,
    ``f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"``.

    The model is always saved through ``self.save_model``. Only when
    ``self.args.should_save`` is truthy are the following also written there:

    * the trainer state as JSON (``TRAINER_STATE_NAME``), after the current
      ``TrainerControl`` state is recorded in ``state.stateful_callbacks``;
    * the optimizer ``state_dict`` (``OPTIMIZER_NAME``);
    * the LR scheduler ``state_dict`` (``SCHEDULER_NAME``).

    Args:
        checkpoint_dir: Parent directory under which the step-named
            checkpoint folder is placed.
    """
    step_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")

    def _artifact(filename: str) -> str:
        # Full path of one checkpoint file inside the step directory.
        return os.path.join(step_dir, filename)

    # The model save happens before (and independently of) the `should_save` gate.
    # NOTE(review): presumably save_model coordinates multi-process saving itself,
    # which is why every process calls it — confirm against save_model.
    self.save_model(step_dir, _internal_call=True)

    if not self.args.should_save:
        return

    # Capture where `TrainerControl` currently stands so it is persisted with the state.
    self.state.stateful_callbacks["TrainerControl"] = self.control.state()
    self.state.save_to_json(_artifact(TRAINER_STATE_NAME))
    torch.save(self.optimizer.state_dict(), _artifact(OPTIMIZER_NAME))
    torch.save(self.lr_scheduler.state_dict(), _artifact(SCHEDULER_NAME))
| 4370 | |
| 4371 | # ---- Callbacks ---- |
| 4372 |
No test coverage detected.