Set up training arguments and accelerator state for a hyperparameter search trial.
(self, trial: "optuna.Trial | dict[str, Any] | None")
| 4277 | return model |
| 4278 | |
def _hp_search_setup(self, trial: "optuna.Trial | dict[str, Any] | None") -> None:
    """
    Apply one hyperparameter-search trial to `self.args` and re-create the accelerator.

    The trial is always stored on `self._trial`. When no search backend is configured, or `trial` is `None`,
    nothing else happens. Otherwise the trial's hyperparameters are written onto `self.args` (each value is
    converted to the type of the value it replaces, unless that value is `None`), the DeepSpeed configuration
    is regenerated when DeepSpeed is enabled, and the accelerator is re-created.

    Args:
        trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
            The trial being run. With the Optuna backend it is passed to `self.hp_space` to obtain the
            hyperparameters; with the Ray and W&B backends it is used directly as the hyperparameter mapping.
            With the Ray backend its `"wandb"` entry, if any, is removed in place.

    Raises:
        ValueError: If DeepSpeed is enabled but `self.args.deepspeed` is `None`.
    """
    self._trial = trial

    backend = self.hp_search_backend
    if backend is None:
        return
    if trial is None:
        return

    # Resolve the hyperparameter mapping for this trial according to the search backend.
    if backend == HPSearchBackend.OPTUNA:
        hyperparameters = self.hp_space(trial)
    elif backend == HPSearchBackend.RAY:
        hyperparameters = trial
        # Drop the `wandb` entry so it is not treated as a training argument below.
        hyperparameters.pop("wandb", None)
    elif backend == HPSearchBackend.WANDB:
        hyperparameters = trial

    args = self.args
    for name, raw_value in hyperparameters.items():
        if not hasattr(args, name):
            logger.warning(
                f"Trying to set {name} in the hyperparameter search but there is no corresponding field in"
                " `TrainingArguments`."
            )
            continue
        current = getattr(args, name, None)
        # Convert the new value to the type of the value it replaces; a `None` current value gives no type to use.
        new_value = raw_value if current is None else type(current)(raw_value)
        setattr(args, name, new_value)

    if backend == HPSearchBackend.OPTUNA:
        logger.info(f"Trial: {trial.params}")
    elif backend == HPSearchBackend.WANDB:
        logger.info(f"W&B Sweep parameters: {trial}")

    if self.is_deepspeed_enabled:
        if args.deepspeed is None:
            raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")

        self.accelerator.free_memory()

        # The training arguments just changed, so the deepspeed config built from them is regenerated.
        from accelerate.utils import DeepSpeedPlugin

        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        ds_config = HfTrainerDeepSpeedConfig(args.deepspeed)
        args.hf_deepspeed_config = ds_config
        ds_config.trainer_config_process(args)
        args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)

        # From 1.0 on, the DS plugin has to be fully wiped when doing sweeps.
        # Calling `_reset_state` is enough for that and needs no version pin.
        AcceleratorState()._reset_state()

    # HPO may have changed `train_batch_size` (https://github.com/huggingface/transformers/pull/18918),
    # so the cached value is refreshed before the accelerator is re-created.
    self._train_batch_size = args.train_batch_size
    self.create_accelerator_and_postprocess()
| 4332 | |
| 4333 | def _report_to_hp_search( |
| 4334 | self, trial: "optuna.Trial | dict[str, Any] | None", step: int, metrics: dict[str, float] |
no test coverage detected