Perform a training step on a batch of inputs. Subclass and override to inject custom behavior. Args: model (`nn.Module`): The model to train. inputs (`dict[str, torch.Tensor | Any]`): The inputs and targets of the model.
(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: torch.Tensor | int | None = None,
)
| 1874 | return TrainOutput(self.state.global_step, train_loss, metrics) |
| 1875 | |
| 1876 | def training_step( |
| 1877 | self, |
| 1878 | model: nn.Module, |
| 1879 | inputs: dict[str, torch.Tensor | Any], |
| 1880 | num_items_in_batch: torch.Tensor | int | None = None, |
| 1881 | ) -> torch.Tensor: |
| 1882 | """ |
| 1883 | Perform a training step on a batch of inputs. |
| 1884 | |
| 1885 | Subclass and override to inject custom behavior. |
| 1886 | |
| 1887 | Args: |
| 1888 | model (`nn.Module`): |
| 1889 | The model to train. |
| 1890 | inputs (`dict[str, torch.Tensor | Any]`): |
| 1891 | The inputs and targets of the model. |
| 1892 | |
| 1893 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the |
| 1894 | argument `labels`. Check your model's documentation for all accepted arguments. |
| 1895 | |
| 1896 | Return: |
| 1897 | `torch.Tensor`: The tensor with training loss on this batch. |
| 1898 | """ |
| 1899 | # Prepare buffers for context parallelism |
| 1900 | |
| 1901 | cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs) |
| 1902 | |
| 1903 | # Context manager is no-op if CP isn't enabled |
| 1904 | with cp_context(): |
| 1905 | model.train() |
| 1906 | if hasattr(self.optimizer, "train") and callable(self.optimizer.train): |
| 1907 | self.optimizer.train() |
| 1908 | |
| 1909 | inputs = self._prepare_inputs(inputs) |
| 1910 | if is_sagemaker_mp_enabled(): |
| 1911 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) |
| 1912 | return loss_mb.reduce_mean().detach().to(self.args.device) |
| 1913 | |
| 1914 | with self.compute_loss_context_manager(): |
| 1915 | loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) |
| 1916 | |
| 1917 | del inputs |
| 1918 | if ( |
| 1919 | self.args.torch_empty_cache_steps is not None |
| 1920 | and self.state.global_step % self.args.torch_empty_cache_steps == 0 |
| 1921 | ): |
| 1922 | clear_device_cache() |
| 1923 | |
| 1924 | kwargs = {} |
| 1925 | |
| 1926 | # For LOMO optimizers you need to explicitly use the learning rate |
| 1927 | if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: |
| 1928 | kwargs["learning_rate"] = self._get_learning_rate() |
| 1929 | |
| 1930 | if self.args.n_gpu > 1: |
| 1931 | loss = loss.mean() # mean() to average on multi-gpu parallel training |
| 1932 | |
| 1933 | # Finally we need to normalize the loss for reporting if GA loss bug is not fixed during compute loss |