Unified API to get any scheduler from its name. Args: name (`str` or `SchedulerType`): The name of the scheduler to use. optimizer (`torch.optim.Optimizer`): The optimizer that will be used during training. num_warmup_steps (`int`, *optional*): The number of warmup steps to do.
(
name: str | SchedulerType,
optimizer: Optimizer,
num_warmup_steps: int | None = None,
num_training_steps: int | None = None,
scheduler_specific_kwargs: dict | None = None,
)
| 958 | |
| 959 | |
| 960 | def get_scheduler( |
| 961 | name: str | SchedulerType, |
| 962 | optimizer: Optimizer, |
| 963 | num_warmup_steps: int | None = None, |
| 964 | num_training_steps: int | None = None, |
| 965 | scheduler_specific_kwargs: dict | None = None, |
| 966 | ): |
| 967 | """ |
| 968 | Unified API to get any scheduler from its name. |
| 969 | |
| 970 | Args: |
| 971 | name (`str` or `SchedulerType`): |
| 972 | The name of the scheduler to use. |
| 973 | optimizer (`torch.optim.Optimizer`): |
| 974 | The optimizer that will be used during training. |
| 975 | num_warmup_steps (`int`, *optional*): |
| 976 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being |
| 977 | optional), the function will raise an error if it's unset and the scheduler type requires it. |
| 978 | num_training_steps (`int`, *optional*): |
| 979 | The number of training steps to do. This is not required by all schedulers (hence the argument being |
| 980 | optional), the function will raise an error if it's unset and the scheduler type requires it. |
| 981 | scheduler_specific_kwargs (`dict`, *optional*): |
| 982 | Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler |
| 983 | parameters will cause the scheduler function to raise a TypeError. |
| 984 | """ |
| 985 | name = SchedulerType(name) |
| 986 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] |
| 987 | |
| 988 | # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and |
| 989 | # recursively call `get_scheduler` to get the proper schedulers on each parameter |
| 990 | if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer): |
| 991 | optimizer_dict = optimizer.optimizer_dict |
| 992 | scheduler_dict = {} |
| 993 | |
| 994 | for param in optimizer_dict: |
| 995 | scheduler_dict[param] = get_scheduler( |
| 996 | name, |
| 997 | optimizer=optimizer_dict[param], |
| 998 | num_warmup_steps=num_warmup_steps, |
| 999 | num_training_steps=num_training_steps, |
| 1000 | scheduler_specific_kwargs=scheduler_specific_kwargs, |
| 1001 | ) |
| 1002 | |
| 1003 | def scheduler_hook(param): |
| 1004 | # Since the optimizer hook has been already attached we only need to |
| 1005 | # attach the scheduler hook, the gradients have been zeroed here |
| 1006 | scheduler_dict[param].step() |
| 1007 | |
| 1008 | for param in optimizer_dict: |
| 1009 | if param.requires_grad: |
| 1010 | param.register_post_accumulate_grad_hook(scheduler_hook) |
| 1011 | |
| 1012 | return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"]) |
| 1013 | |
| 1014 | if name == SchedulerType.CONSTANT: |
| 1015 | return schedule_func(optimizer) |
| 1016 | |
| 1017 | if scheduler_specific_kwargs is None: |