MCPcopy
hub / github.com/huggingface/transformers / create_optimizer

Method create_optimizer

src/transformers/trainer.py:1152–1226  ·  view source on GitHub ↗

Set up the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method. Returns: `torch.optim.Optimizer`: The optimizer instance.

(self, model=None)

Source from the content-addressed store, hash-verified

1150 self.create_scheduler(num_training_steps=num_training_steps)
1151
1152 def create_optimizer(self, model=None) -> torch.optim.Optimizer:
1153 """
1154 Setup the optimizer.
1155
1156 We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
1157 Trainer's init through `optimizers`, or subclass and override this method in a subclass.
1158
1159 Returns:
1160 `torch.optim.Optimizer`: The optimizer instance.
1161 """
1162 opt_model = self.model if model is None else model
1163
1164 if self.optimizer is None:
1165 decay_parameters = self.get_decay_parameter_names(opt_model)
1166 optimizer_grouped_parameters = [
1167 {
1168 "params": [
1169 p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
1170 ],
1171 "weight_decay": self.args.weight_decay,
1172 },
1173 {
1174 "params": [
1175 p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
1176 ],
1177 "weight_decay": 0.0,
1178 },
1179 ]
1180
1181 if self.optimizer_cls_and_kwargs is not None:
1182 optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
1183 else:
1184 optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
1185
1186 # Check if this is a factory (for complex optimizers like Muon, Dion)
1187 # Factories are instantiated first, then called with (opt_model, **kwargs)
1188 if is_optimizer_factory(optimizer_cls):
1189 self.optimizer = optimizer_cls()(opt_model, **optimizer_kwargs)
1190 else:
1191 # Standard optimizer class instantiation
1192 # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
1193 # e.g. for GaLore optimizer.
1194 if "params" in optimizer_kwargs:
1195 optimizer_grouped_parameters = optimizer_kwargs.pop("params")
1196
1197 # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
1198 # e.g. for LOMO optimizer.
1199 if "model" in optimizer_kwargs:
1200 optimizer_grouped_parameters = optimizer_kwargs.pop("model")
1201
1202 # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
1203 # to avoid arguments conflicts.
1204 if "optimizer_dict" in optimizer_kwargs:
1205 optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
1206
1207 self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
1208
1209 if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:

Calls 9

is_optimizer_factory — Function · 0.85
is_sagemaker_mp_enabled — Function · 0.85
pop — Method · 0.45
get — Method · 0.45
values — Method · 0.45
info — Method · 0.45
debug — Method · 0.45