Set up the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method. Returns: `torch.optim.Optimizer`: The optimizer instance.
(self, model=None)
| 1150 | self.create_scheduler(num_training_steps=num_training_steps) |
| 1151 | |
| 1152 | def create_optimizer(self, model=None) -> torch.optim.Optimizer: |
| 1153 | """ |
| 1154 | Setup the optimizer. |
| 1155 | |
| 1156 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the |
| 1157 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. |
| 1158 | |
| 1159 | Returns: |
| 1160 | `torch.optim.Optimizer`: The optimizer instance. |
| 1161 | """ |
| 1162 | opt_model = self.model if model is None else model |
| 1163 | |
| 1164 | if self.optimizer is None: |
| 1165 | decay_parameters = self.get_decay_parameter_names(opt_model) |
| 1166 | optimizer_grouped_parameters = [ |
| 1167 | { |
| 1168 | "params": [ |
| 1169 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) |
| 1170 | ], |
| 1171 | "weight_decay": self.args.weight_decay, |
| 1172 | }, |
| 1173 | { |
| 1174 | "params": [ |
| 1175 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) |
| 1176 | ], |
| 1177 | "weight_decay": 0.0, |
| 1178 | }, |
| 1179 | ] |
| 1180 | |
| 1181 | if self.optimizer_cls_and_kwargs is not None: |
| 1182 | optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs |
| 1183 | else: |
| 1184 | optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) |
| 1185 | |
| 1186 | # Check if this is a factory (for complex optimizers like Muon, Dion) |
| 1187 | # Factories are instantiated first, then called with (opt_model, **kwargs) |
| 1188 | if is_optimizer_factory(optimizer_cls): |
| 1189 | self.optimizer = optimizer_cls()(opt_model, **optimizer_kwargs) |
| 1190 | else: |
| 1191 | # Standard optimizer class instantiation |
| 1192 | # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` |
| 1193 | # e.g. for GaLore optimizer. |
| 1194 | if "params" in optimizer_kwargs: |
| 1195 | optimizer_grouped_parameters = optimizer_kwargs.pop("params") |
| 1196 | |
| 1197 | # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` |
| 1198 | # e.g. for LOMO optimizer. |
| 1199 | if "model" in optimizer_kwargs: |
| 1200 | optimizer_grouped_parameters = optimizer_kwargs.pop("model") |
| 1201 | |
| 1202 | # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` |
| 1203 | # to avoid arguments conflicts. |
| 1204 | if "optimizer_dict" in optimizer_kwargs: |
| 1205 | optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") |
| 1206 | |
| 1207 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) |
| 1208 | |
| 1209 | if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8: |