MCPcopy
hub / github.com/huggingface/transformers / _hp_search_setup

Method _hp_search_setup

src/transformers/trainer.py:4279–4331  ·  view source on GitHub ↗

Set up training arguments and accelerator state for a hyperparameter search trial.

(self, trial: "optuna.Trial | dict[str, Any] | None")

Source from the content-addressed store, hash-verified

4277 return model
4278
def _hp_search_setup(self, trial: "optuna.Trial | dict[str, Any] | None") -> None:
    """Set up training arguments and accelerator state for a hyperparameter search trial.

    Always records ``trial`` on ``self._trial``. When a search backend is configured and
    ``trial`` is not ``None``, the trial's hyperparameters are written onto ``self.args``
    (each cast to the type of the field's current value when that value is not ``None``).
    If DeepSpeed is enabled, its config and plugin are then rebuilt from the updated
    arguments. Finally, ``self._train_batch_size`` is refreshed and the accelerator is
    recreated.

    Args:
        trial: The current trial. For the Optuna backend it is passed to ``self.hp_space``
            to obtain the parameter dict. For the Ray and W&B backends it is used directly
            as the parameter dict. ``None`` means there is nothing to apply.

    Raises:
        ValueError: If ``self.hp_search_backend`` is not one of the handled backends, or if
            DeepSpeed is enabled but ``self.args.deepspeed`` is not set.
    """
    self._trial = trial

    if self.hp_search_backend is None or trial is None:
        return

    if self.hp_search_backend == HPSearchBackend.OPTUNA:
        params = self.hp_space(trial)
    elif self.hp_search_backend == HPSearchBackend.RAY:
        params = trial
        # `params` is the same dict object as `trial` (and `self._trial`), so this pop
        # removes the `wandb` entry from the caller's trial as well.
        # Presumably the entry is not a TrainingArguments field — confirm against the Ray integration.
        params.pop("wandb", None)
    elif self.hp_search_backend == HPSearchBackend.WANDB:
        params = trial
    else:
        # Without this guard an unrecognized backend would leave `params` unbound and fail
        # below with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported hyperparameter search backend: {self.hp_search_backend}")

    for key, value in params.items():
        if not hasattr(self.args, key):
            logger.warning(
                f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                " `TrainingArguments`."
            )
            continue
        old_attr = getattr(self.args, key)
        # Cast the sampled value to the type of the field's current value. Fields currently set to
        # None are assigned as-is.
        # NOTE(review): the cast uses the type's constructor, so e.g. bool("False") is True — string-valued
        # booleans from a backend would be converted incorrectly; confirm backends emit real bools.
        if old_attr is not None:
            value = type(old_attr)(value)
        setattr(self.args, key, value)

    if self.hp_search_backend == HPSearchBackend.OPTUNA:
        logger.info(f"Trial: {trial.params}")
    if self.hp_search_backend == HPSearchBackend.WANDB:
        logger.info(f"W&B Sweep parameters: {trial}")

    if self.is_deepspeed_enabled:
        if self.args.deepspeed is None:
            raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")

        self.accelerator.free_memory()

        # Rebuild the deepspeed config to reflect the updated training parameters
        from accelerate.utils import DeepSpeedPlugin

        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
        self.args.hf_deepspeed_config.trainer_config_process(self.args)
        self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)

        # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps.
        # Simply calling `_reset_state` is enough and doesn't need a version pin.
        AcceleratorState()._reset_state()

    # `train_batch_size` might change when using HPO https://github.com/huggingface/transformers/pull/18918
    self._train_batch_size = self.args.train_batch_size
    self.create_accelerator_and_postprocess()
4332
4333 def _report_to_hp_search(
4334 self, trial: "optuna.Trial | dict[str, Any] | None", step: int, metrics: dict[str, float]

Callers 1

train · Method · 0.95

Calls 8

warning · Method · 0.80
hp_space · Method · 0.45
pop · Method · 0.45
items · Method · 0.45
info · Method · 0.45

Tested by

no test coverage detected