MCPcopy
hub / github.com/huggingface/transformers / _maybe_log_save_evaluate

Method _maybe_log_save_evaluate

src/transformers/trainer.py:2055–2103  ·  view source on GitHub ↗

Log metrics, run evaluation, and save checkpoints if the current training state requires it.

(
        self,
        tr_loss: torch.Tensor,
        grad_norm: torch.Tensor | float | None,
        model: nn.Module,
        trial: "optuna.Trial | dict[str, Any] | None",
        epoch: float,
        ignore_keys_for_eval: list[str] | None,
        start_time: float,
        learning_rate: float | None = None,
    )

Source from the content-addressed store, hash-verified

2053 return contextlib.nullcontext()
2054
def _maybe_log_save_evaluate(
    self,
    tr_loss: torch.Tensor,
    grad_norm: torch.Tensor | float | None,
    model: nn.Module,
    trial: "optuna.Trial | dict[str, Any] | None",
    epoch: float,
    ignore_keys_for_eval: list[str] | None,
    start_time: float,
    learning_rate: float | None = None,
) -> None:
    """Log metrics, run evaluation, and save checkpoints if the current training state requires it.

    The three phases run in order and are gated by ``self.control.should_log``,
    ``self.control.should_evaluate`` and ``self.control.should_save`` respectively.
    Logging additionally requires at least one new global step since the last log.

    Args:
        tr_loss: Loss accumulated since the last log entry. It is gathered across
            processes, averaged, and then reset to zero *in place*, so the caller's
            tensor is mutated.
        grad_norm: Latest gradient norm; logged (as a Python number) when not ``None``.
        model: Model handed to ``_save_checkpoint``.
        trial: Hyperparameter-search trial forwarded to evaluation, best-metric
            tracking and checkpointing.
        epoch: Supplied by the caller but not read in this body.
        ignore_keys_for_eval: Forwarded to ``_evaluate``.
        start_time: Forwarded to ``self.log`` (presumably the training start time).
        learning_rate: Value to log as ``learning_rate``; when ``None``,
            ``self._get_learning_rate()`` is used instead.
    """
    if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
        if is_torch_xla_available():
            # Presumably materializes pending lazy XLA work before the `.item()` below.
            xm.mark_step()

        # all_gather + mean() to get average loss over all processes
        tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()

        # Reset the accumulator in place. `zero_()` is used instead of
        # `tr_loss -= tr_loss` because subtraction yields NaN when the
        # accumulated loss is inf/NaN (inf - inf = NaN), which would poison
        # every subsequently logged loss.
        tr_loss.zero_()

        # Average over the global steps since the previous log; the guard above
        # guarantees this divisor is strictly positive.
        steps_since_last_log = self.state.global_step - self._globalstep_last_logged
        logs: dict[str, float] = {"loss": tr_loss_scalar / steps_since_last_log}
        if grad_norm is not None:
            logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
        logs["learning_rate"] = learning_rate if learning_rate is not None else self._get_learning_rate()

        self._total_loss_scalar += tr_loss_scalar
        self._globalstep_last_logged = self.state.global_step
        self.store_flos()

        self.log(logs, start_time)

    if self.control.should_evaluate:
        metrics = self._evaluate(trial, ignore_keys_for_eval)
        is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)

        # Under the "best" save strategy, saving is decided solely by whether
        # this evaluation produced a new best metric.
        if self.args.save_strategy == SaveStrategy.BEST:
            self.control.should_save = is_new_best_metric

    if self.control.should_save:
        self._save_checkpoint(model, trial)
        self.control = self.callback_handler.on_save(self.args, self.state, self.control)
2104
2105 # ---- Training Utilities ----
2106 def get_batch_samples(

Callers 1

_run_epoch (Method) · 0.95

Calls 10

_get_learning_rate (Method) · 0.95
store_flos (Method) · 0.95
log (Method) · 0.95
_evaluate (Method) · 0.95
_save_checkpoint (Method) · 0.95
is_torch_xla_available (Function) · 0.85
nested_gather (Function) · 0.85
mean (Method) · 0.45
on_save (Method) · 0.45

Tested by

no test coverage detected