MCPcopy
hub / github.com/huggingface/transformers / create_scheduler

Method create_scheduler

src/transformers/trainer.py:1228–1256  ·  view source on GitHub ↗

Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument. Args: num_training_steps (int): The number of training steps to do. Returns: `torch.optim.lr_scheduler.LRScheduler`: The learning rate scheduler instance.

(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer | None = None
    )

Source from the content-addressed store, hash-verified

1226 return self.optimizer
1227
1228 def create_scheduler(
1229 self, num_training_steps: int, optimizer: torch.optim.Optimizer | None = None
1230 ) -> torch.optim.lr_scheduler.LRScheduler:
1231 """
1232 Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
1233 passed as an argument.
1234
1235 Args:
1236 num_training_steps (int): The number of training steps to do.
1237
1238 Returns:
1239 `torch.optim.lr_scheduler.LRScheduler`: The learning rate scheduler instance.
1240 """
1241 if self.lr_scheduler is None:
1242 if optimizer is None:
1243 if is_sagemaker_mp_enabled() and smp.state.cfg.fp16:
1244 # If fp16 is enabled, we unwrap the optimizer
1245 optimizer = self.optimizer.optimizer
1246 else:
1247 optimizer = self.optimizer
1248 self.lr_scheduler = get_scheduler(
1249 self.args.lr_scheduler_type,
1250 optimizer=optimizer,
1251 num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
1252 num_training_steps=num_training_steps,
1253 scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
1254 )
1255 self._created_lr_scheduler = True
1256 return self.lr_scheduler
1257
1258 @staticmethod
1259 def get_optimizer_cls_and_kwargs(args: TrainingArguments, model: PreTrainedModel | None = None) -> tuple[Any, Any]:

Callers 3

_prepare_for_training · Method · 0.95
_lr_scheduler_callable · Function · 0.80

Calls 3

is_sagemaker_mp_enabled · Function · 0.85
get_warmup_steps · Method · 0.80
get_scheduler · Function · 0.70

Tested by

no test coverage detected