Validate constructor arguments and fail fast on incompatible combinations.
(self)
| 618 | self._memory_tracker.stop_and_update_metrics() |
| 619 | |
| 620 | def _validate_args(self) -> None: |
| 621 | """Validate constructor arguments and fail fast on incompatible combinations.""" |
| 622 | args = self.args |
| 623 | |
| 624 | # --- SageMaker Model Parallel mixed-precision validation --- |
| 625 | if is_sagemaker_mp_enabled(): |
| 626 | if args.bf16: |
| 627 | raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") |
| 628 | if args.fp16 != smp.state.cfg.fp16: |
| 629 | logger.warning( |
| 630 | f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " |
| 631 | f"but FP16 provided in trainer argument is {args.fp16}, " |
| 632 | f"setting to {smp.state.cfg.fp16}" |
| 633 | ) |
| 634 | args.fp16 = smp.state.cfg.fp16 |
| 635 | |
| 636 | # --- Training-argument validations --- |
| 637 | if args.batch_eval_metrics and self.compute_metrics is not None: |
| 638 | if "compute_result" not in inspect.signature(self.compute_metrics).parameters: |
| 639 | raise ValueError( |
| 640 | "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`" |
| 641 | " boolean argument which will be triggered after the last batch of the eval set to signal that the" |
| 642 | " summary statistics should be returned by the function." |
| 643 | ) |
| 644 | if args.eval_strategy is not None and args.eval_strategy != "no" and self.eval_dataset is None: |
| 645 | raise ValueError( |
| 646 | f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. " |
| 647 | ) |
| 648 | if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end: |
| 649 | if args.metric_for_best_model is None: |
| 650 | raise ValueError( |
| 651 | "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`." |
| 652 | ) |
| 653 | |
| 654 | # --- Optimizer validations --- |
| 655 | if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None: |
| 656 | raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.") |
| 657 | if self.model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): |
| 658 | raise RuntimeError( |
| 659 | "Passing a `model_init` is incompatible with providing the `optimizers` argument. " |
| 660 | "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." |
| 661 | ) |
| 662 | if is_torch_xla_available() and self.optimizer is not None: |
| 663 | for param in self.model.parameters(): |
| 664 | model_device = param.device |
| 665 | break |
| 666 | for param_group in self.optimizer.param_groups: |
| 667 | if len(param_group["params"]) > 0: |
| 668 | optimizer_device = param_group["params"][0].device |
| 669 | break |
| 670 | if model_device != optimizer_device: |
| 671 | raise ValueError( |
| 672 | "The model and the optimizer parameters are not on the same device, which probably means you" |
| 673 | " created an optimizer around your model **before** putting on the device and passing it to the" |
| 674 | " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" |
| 675 | " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." |
| 676 | ) |
| 677 | if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and ( |
no test coverage detected