MCPcopy
hub / github.com/huggingface/transformers / training_step

Method training_step

src/transformers/trainer.py:1876–1945  ·  view source on GitHub ↗

Perform a training step on a batch of inputs. Subclass and override to inject custom behavior. Args: model (`nn.Module`): The model to train. inputs (`dict[str, torch.Tensor | Any]`): The inputs and targets of the model.

(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        num_items_in_batch: torch.Tensor | int | None = None,
    )

Source from the content-addressed store, hash-verified

1874 return TrainOutput(self.state.global_step, train_loss, metrics)
1875
1876 def training_step(
1877 self,
1878 model: nn.Module,
1879 inputs: dict[str, torch.Tensor | Any],
1880 num_items_in_batch: torch.Tensor | int | None = None,
1881 ) -> torch.Tensor:
1882 """
1883 Perform a training step on a batch of inputs.
1884
1885 Subclass and override to inject custom behavior.
1886
1887 Args:
1888 model (`nn.Module`):
1889 The model to train.
1890 inputs (`dict[str, torch.Tensor | Any]`):
1891 The inputs and targets of the model.
1892
1893 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
1894 argument `labels`. Check your model's documentation for all accepted arguments.
1895
1896 Return:
1897 `torch.Tensor`: The tensor with training loss on this batch.
1898 """
1899 # Prepare buffers for context parallelism
1900
1901 cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)
1902
1903 # Context manager is no-op if CP isn't enabled
1904 with cp_context():
1905 model.train()
1906 if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
1907 self.optimizer.train()
1908
1909 inputs = self._prepare_inputs(inputs)
1910 if is_sagemaker_mp_enabled():
1911 loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
1912 return loss_mb.reduce_mean().detach().to(self.args.device)
1913
1914 with self.compute_loss_context_manager():
1915 loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
1916
1917 del inputs
1918 if (
1919 self.args.torch_empty_cache_steps is not None
1920 and self.state.global_step % self.args.torch_empty_cache_steps == 0
1921 ):
1922 clear_device_cache()
1923
1924 kwargs = {}
1925
1926 # For LOMO optimizers you need to explicitly use the learning rate
1927 if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
1928 kwargs["learning_rate"] = self._get_learning_rate()
1929
1930 if self.args.n_gpu > 1:
1931 loss = loss.mean() # mean() to average on multi-gpu parallel training
1932
1933 # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss

Callers 2

_run_epochMethod · 0.95

Calls 12

_prepare_inputsMethod · 0.95
compute_lossMethod · 0.95
_get_learning_rateMethod · 0.95
is_sagemaker_mp_enabledFunction · 0.85
smp_forward_backwardFunction · 0.85
detachMethod · 0.80
trainMethod · 0.45
toMethod · 0.45
meanMethod · 0.45
backwardMethod · 0.45

Tested by 1