Repository: github.com/huggingface/transformers

Function get_scheduler

Defined in src/transformers/optimization.py, lines 960–1054.

Unified API to get any learning-rate scheduler from its name.
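The name-to-scheduler lookup can be sketched in plain Python. This is an illustration under assumptions, not the library's code: the enum `SchedulerKind`, the functions `constant` and `linear_decay`, the table `SCHEDULES`, and `lookup` are hypothetical stand-ins for `SchedulerType`, the real schedule functions, and `TYPE_TO_SCHEDULER_FUNCTION`. It shows why `name` may be either a string or an enum member: calling an `Enum` class with an existing member returns that member, and calling it with an undefined value raises `ValueError`.

```python
from enum import Enum
from typing import Callable, Dict, Union


class SchedulerKind(Enum):
    """Hypothetical stand-in for transformers' SchedulerType."""

    LINEAR = "linear"
    CONSTANT = "constant"


def constant(step: int) -> float:
    # Hypothetical schedule: the learning-rate multiplier never changes.
    return 1.0


def linear_decay(step: int, total_steps: int = 100) -> float:
    # Hypothetical schedule: the multiplier falls linearly from 1.0 to 0.0.
    return max(0.0, (total_steps - step) / total_steps)


# Hypothetical stand-in for TYPE_TO_SCHEDULER_FUNCTION.
SCHEDULES: Dict[SchedulerKind, Callable[..., float]] = {
    SchedulerKind.LINEAR: linear_decay,
    SchedulerKind.CONSTANT: constant,
}


def lookup(name: Union[str, SchedulerKind]) -> Callable[..., float]:
    # Enum(value) accepts the raw string or an existing member and
    # raises ValueError for any value that is not defined on the enum.
    kind = SchedulerKind(name)
    return SCHEDULES[kind]


if __name__ == "__main__":
    print(lookup("linear")(25))  # prints 0.75
    print(lookup(SchedulerKind.CONSTANT)(25))  # prints 1.0
    try:
        lookup("cosine")
    except ValueError as err:
        print("rejected:", err)
```

Normalizing to the enum first means the dictionary only needs enum keys, and a misspelled name fails at the conversion step rather than with a less descriptive `KeyError`.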

(
    name: str | SchedulerType,
    optimizer: Optimizer,
    num_warmup_steps: int | None = None,
    num_training_steps: int | None = None,
    scheduler_specific_kwargs: dict | None = None,
)
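The docstring's rules for the optional arguments (some schedulers need `num_warmup_steps` and `num_training_steps` while others do not, and mismatched `scheduler_specific_kwargs` surface as a `TypeError`) can be illustrated with a small stdlib sketch. Everything in it is hypothetical: `build_schedule`, the two schedule functions, and the choice of `ValueError` for a missing step count (the docstring only promises "an error"). The linear warmup-then-decay multiplier is an assumed example schedule; the `TypeError` is raised by Python itself when a function receives a keyword argument it does not declare.

```python
from typing import Callable, Dict, Optional

Multiplier = Callable[[int], float]


def constant_schedule() -> Multiplier:
    # Hypothetical schedule that needs neither step count.
    return lambda step: 1.0


def linear_warmup_schedule(num_warmup_steps: int, num_training_steps: int) -> Multiplier:
    # Hypothetical schedule: ramp from 0 to 1 during warmup, then decay linearly to 0.
    def multiplier(step: int) -> float:
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        remaining = num_training_steps - step
        return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

    return multiplier


SCHEDULES_NEEDING_STEPS: Dict[str, Callable[..., Multiplier]] = {
    "linear": linear_warmup_schedule,
}


def build_schedule(
    name: str,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    scheduler_specific_kwargs: Optional[dict] = None,
) -> Multiplier:
    # The constant schedule returns early, before any step count is checked.
    if name == "constant":
        return constant_schedule()
    if scheduler_specific_kwargs is None:
        scheduler_specific_kwargs = {}
    # Assumption: missing step counts are reported as ValueError.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`")
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`")
    # Unknown keys in scheduler_specific_kwargs make this call raise TypeError.
    return SCHEDULES_NEEDING_STEPS[name](
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        **scheduler_specific_kwargs,
    )


if __name__ == "__main__":
    schedule = build_schedule("linear", num_warmup_steps=10, num_training_steps=110)
    print([schedule(s) for s in (0, 5, 10, 60, 110)])  # [0.0, 0.5, 1.0, 0.5, 0.0]
```

Passing `{"num_cycles": 2}` as `scheduler_specific_kwargs` for `"linear"` fails with `TypeError`, because `linear_warmup_schedule` declares no such parameter; no explicit validation code is needed for that case.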


```python
def get_scheduler(
    name: str | SchedulerType,
    optimizer: Optimizer,
    num_warmup_steps: int | None = None,
    num_training_steps: int | None = None,
    scheduler_specific_kwargs: dict | None = None,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional); the function raises an error if it is unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional); the function raises an error if it is unset and the scheduler type requires it.
        scheduler_specific_kwargs (`dict`, *optional*):
            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
            parameters will cause the scheduler function to raise a TypeError.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
    # recursively call `get_scheduler` to get the proper schedulers on each parameter
    if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
        optimizer_dict = optimizer.optimizer_dict
        scheduler_dict = {}

        for param in optimizer_dict:
            scheduler_dict[param] = get_scheduler(
                name,
                optimizer=optimizer_dict[param],
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                scheduler_specific_kwargs=scheduler_specific_kwargs,
            )

        def scheduler_hook(param):
            # Since the optimizer hook has been already attached we only need to
            # attach the scheduler hook, the gradients have been zeroed here
            scheduler_dict[param].step()

        for param in optimizer_dict:
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(scheduler_hook)

        return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])

    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    if scheduler_specific_kwargs is None:
        ...  # excerpt ends here; the rest of the function (through line 1054) is not shown
```
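The layer-wise branch above creates one scheduler per parameter and advances each one from a hook that fires once that parameter's gradient has been accumulated, rather than from the training loop. The stdlib-only sketch below imitates just that wiring. `Param`, `CountingScheduler`, `accumulate_grad`, and `attach_layerwise_schedulers` are hypothetical stand-ins (with PyTorch, the real hook is registered through `Tensor.register_post_accumulate_grad_hook`), and the schedulers only count their steps instead of changing a learning rate.

```python
from typing import Callable, Dict, List


class Param:
    """Hypothetical stand-in for a torch parameter."""

    def __init__(self, name: str, requires_grad: bool = True) -> None:
        self.name = name
        self.requires_grad = requires_grad
        self._hooks: List[Callable[["Param"], None]] = []

    def register_post_accumulate_grad_hook(self, hook: Callable[["Param"], None]) -> None:
        self._hooks.append(hook)

    def accumulate_grad(self) -> None:
        # Simulates autograd finishing gradient accumulation for this parameter.
        for hook in self._hooks:
            hook(self)


class CountingScheduler:
    """Hypothetical scheduler that only counts how often it was stepped."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


def attach_layerwise_schedulers(params: List[Param]) -> Dict[Param, CountingScheduler]:
    # One scheduler per parameter, mirroring the recursive get_scheduler calls.
    scheduler_dict = {p: CountingScheduler() for p in params}

    def scheduler_hook(param: Param) -> None:
        # Look up this parameter's own scheduler and advance it.
        scheduler_dict[param].step()

    for p in params:
        if p.requires_grad:  # frozen parameters never receive gradients
            p.register_post_accumulate_grad_hook(scheduler_hook)
    return scheduler_dict


if __name__ == "__main__":
    weight, bias, frozen = Param("weight"), Param("bias"), Param("frozen", requires_grad=False)
    schedulers = attach_layerwise_schedulers([weight, bias, frozen])
    weight.accumulate_grad()
    weight.accumulate_grad()
    bias.accumulate_grad()
    frozen.accumulate_grad()  # no hook registered, so nothing happens
    print({p.name: s.steps for p, s in schedulers.items()})  # {'weight': 2, 'bias': 1, 'frozen': 0}
```

Because the hook finds its scheduler by parameter identity, a frozen parameter still gets a dictionary entry but is never stepped, which matches the `requires_grad` check in the source.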

Callers (15; 10 listed): test_get_scheduler (method) and nine functions named main.

Calls (2; 1 listed): SchedulerType (class).