Returns the training [`~torch.utils.data.DataLoader`]. Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed training if necessary) otherwise. Subclass and override this method if you want to inject some custom behavior.
(self)
| 869 | # ---- Data Loading ---- |
| 870 | |
def get_train_dataloader(self) -> DataLoader:
    """
    Build the [`~torch.utils.data.DataLoader`] used for training.

    The sampler is chosen through `self._get_train_sampler`: no sampler when `train_dataset` does not implement
    `__len__`, a random sampler (adapted to distributed training if necessary) otherwise.

    Raises a `ValueError` when no `train_dataset` has been set on the trainer.

    Subclass and override this method if you want to inject some custom behavior.
    """
    dataset = self.train_dataset
    if dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    # Resolve the shared factory first, then gather the training-specific settings it expects.
    build_loader = self._get_dataloader
    loader_options = {
        "dataset": dataset,
        "description": "Training",
        "batch_size": self._train_batch_size,
        "sampler_fn": self._get_train_sampler,
        "is_training": True,
    }
    return build_loader(**loader_options)
| 890 | |
| 891 | def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader: |
| 892 | """ |