Log metrics, run evaluation, and save checkpoints if the current training state requires it.
(
self,
tr_loss: torch.Tensor,
grad_norm: torch.Tensor | float | None,
model: nn.Module,
trial: "optuna.Trial | dict[str, Any] | None",
epoch: float,
ignore_keys_for_eval: list[str] | None,
start_time: float,
learning_rate: float | None = None,
)
| 2053 | return contextlib.nullcontext() |
| 2054 | |
def _maybe_log_save_evaluate(
    self,
    tr_loss: torch.Tensor,
    grad_norm: torch.Tensor | float | None,
    model: nn.Module,
    trial: "optuna.Trial | dict[str, Any] | None",
    epoch: float,
    ignore_keys_for_eval: list[str] | None,
    start_time: float,
    learning_rate: float | None = None,
) -> None:
    """Log metrics, run evaluation, and save checkpoints if the current training state requires it.

    Each phase is gated by a flag on ``self.control``:

    * ``should_log`` (and at least one step since the last log): gather the
      accumulated training loss across processes, divide it by the number of
      steps since the last log, log it together with the gradient norm and the
      learning rate, and reset ``tr_loss`` to zero in place.
    * ``should_evaluate``: run evaluation and update best-metric tracking. With
      ``SaveStrategy.BEST`` the outcome decides ``should_save``.
    * ``should_save``: write a checkpoint and dispatch the ``on_save`` callbacks,
      whose return value replaces ``self.control``.

    Args:
        tr_loss: Loss accumulated since the last log (it is divided by the step
            count when logged). Zeroed in place after logging, so the caller's
            tensor is reset as well.
        grad_norm: Latest gradient norm; omitted from the logs when ``None``.
            A tensor is converted with ``.item()``.
        model: Model handed to ``_save_checkpoint``.
        trial: Hyperparameter-search trial (or its parameters); forwarded to
            evaluation, best-metric tracking, and checkpointing.
        epoch: Not referenced in this method body.
        ignore_keys_for_eval: Forwarded to ``_evaluate``.
        start_time: Forwarded to ``self.log``.
        learning_rate: Learning rate to log. When ``None`` it is obtained from
            ``self._get_learning_rate()``.
    """
    if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
        if is_torch_xla_available():
            xm.mark_step()

        logs: dict[str, float] = {}

        # all_gather + mean() to get the average loss over all processes.
        tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()

        # Reset the accumulator in place so the caller's tensor starts again from zero.
        # ``zero_()`` is used rather than ``tr_loss -= tr_loss``: once the accumulated
        # value is NaN or +/-inf, ``x - x`` yields NaN instead of 0, which would poison
        # every loss logged afterwards.
        tr_loss.zero_()

        # Strictly positive because of the guard above, so the division is safe.
        steps_since_last_log = self.state.global_step - self._globalstep_last_logged
        logs["loss"] = tr_loss_scalar / steps_since_last_log
        if grad_norm is not None:
            logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
        # Only query the scheduler/optimizer when no explicit value was supplied.
        logs["learning_rate"] = learning_rate if learning_rate is not None else self._get_learning_rate()

        self._total_loss_scalar += tr_loss_scalar
        self._globalstep_last_logged = self.state.global_step
        self.store_flos()

        self.log(logs, start_time)

    # Read ``self.control`` afresh at each step: the calls below may replace it.
    if self.control.should_evaluate:
        metrics = self._evaluate(trial, ignore_keys_for_eval)
        is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

        # With the "best" save strategy, checkpoint only when a new best metric was reached.
        if self.args.save_strategy == SaveStrategy.BEST:
            self.control.should_save = is_new_best_metric

    if self.control.should_save:
        self._save_checkpoint(model, trial)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
| 2105 | # ---- Training Utilities ---- |
| 2106 | def get_batch_samples( |
No test coverage detected.