Initialize TrainerState, optionally restoring from checkpoint. Returns (epochs_trained, steps_trained_in_current_epoch).
(
self, max_steps, num_update_steps_per_epoch, num_train_epochs, resume_from_checkpoint, trial
)
| 1531 | return self._finalize_training(trial, num_train_samples, start_time) |
| 1532 | |
def _init_training_state(
    self, max_steps, num_update_steps_per_epoch, num_train_epochs, resume_from_checkpoint, trial
) -> tuple[int, int]:
    """
    Set up `self.state` for a training run, reloading it from a checkpoint when one is available.

    A fresh `TrainerState` is always created first. If `resume_from_checkpoint` is given and
    contains a `TRAINER_STATE_NAME` file, the state is replaced by the one stored in that file,
    the checkpoint arguments are compared against `self.args`, and the callback state is reloaded.

    Args:
        max_steps: Forwarded to `TrainerState.compute_steps` and `init_training_references`.
        num_update_steps_per_epoch: Divisor used to split the restored `global_step` into
            completed epochs and the remainder inside the current epoch.
        num_train_epochs: Forwarded to `init_training_references`.
        resume_from_checkpoint: Checkpoint directory, or `None` to start from scratch.
        trial: Hyperparameter-search trial; any non-`None` value flags the state as a search run.

    Returns:
        `(epochs_trained, steps_trained_in_current_epoch)`. Both are 0 when nothing is restored;
        the second one also stays 0 when `self.args.ignore_data_skip` is set.
    """
    # Only callbacks (plus the control object) that are ExportableState instances are tracked.
    exportable_callbacks = []
    for candidate in self.callback_handler.callbacks + [self.control]:
        if isinstance(candidate, ExportableState):
            exportable_callbacks.append(candidate)

    self.state = TrainerState(stateful_callbacks=exportable_callbacks)
    self.state.is_hyper_param_search = trial is not None
    self.state.train_batch_size = self._train_batch_size
    self.state.compute_steps(self.args, max_steps)

    resumed_epochs, resumed_steps = 0, 0

    # Resolve the saved-state path once; it is only meaningful when a checkpoint was requested.
    saved_state_path = None
    if resume_from_checkpoint is not None:
        saved_state_path = os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)

    if saved_state_path is not None and os.path.isfile(saved_state_path):
        self.state = TrainerState.load_from_json(saved_state_path)
        compare_trainer_and_checkpoint_args(self.args, self.state)
        self._load_callback_state()

        completed_updates = self.state.global_step
        resumed_epochs = int(completed_updates // num_update_steps_per_epoch)
        if not self.args.ignore_data_skip:
            # Remainder within the current epoch, scaled by the accumulation factor
            # (presumably converting optimizer updates into batches to skip -- confirm with caller).
            resumed_steps = (
                completed_updates % num_update_steps_per_epoch
            ) * self.args.gradient_accumulation_steps

    # Runs for both fresh and restored states.
    self.state.init_training_references(self, max_steps, num_train_epochs, trial)

    return resumed_epochs, resumed_steps
| 1563 | |
| 1564 | def _prepare_for_training(self, max_steps, train_dataloader, resume_from_checkpoint): |
| 1565 | """Wrap model, create optimizer and scheduler, and run accelerator.prepare. Returns (model, train_dataloader).""" |
no test coverage detected