```python
from typing import Any

import torch


def is_optimizer_factory(optimizer_cls_or_factory: Any) -> bool:
    """
    Check whether the value returned by an optimizer handler is a factory rather than an Optimizer class.

    Factory callables are used for complex optimizers such as Muon or Dion that need to:
    - Split parameters between multiple internal optimizers
    - Handle complex sharding logic
    - Access the full model structure for parameter grouping

    Args:
        optimizer_cls_or_factory: The first element returned by an optimizer handler.

    Returns:
        `bool`: False if the value is a subclass of `torch.optim.Optimizer`; True otherwise,
        in which case the value is treated as a factory.
    """
    # A subclass of torch.optim.Optimizer is an optimizer class, not a factory.
    if isinstance(optimizer_cls_or_factory, type) and issubclass(optimizer_cls_or_factory, torch.optim.Optimizer):
        return False
    # Anything else (functions, partials, non-Optimizer classes, instances) is treated as a factory.
    return True
```
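The check only answers "is this an `Optimizer` subclass?", so a caller has to decide what to pass to each kind of value. The sketch below shows one plausible dispatch; `build_optimizer` and `adam_factory` are hypothetical names, and the assumption that a factory takes the whole model (while a class takes `model.parameters()`) is inferred from the docstring's "access the full model structure", not taken from the original module.

```python
from typing import Any

import torch


def is_optimizer_factory(optimizer_cls_or_factory: Any) -> bool:
    # Same check as the function above: only Optimizer subclasses count as classes.
    return not (
        isinstance(optimizer_cls_or_factory, type)
        and issubclass(optimizer_cls_or_factory, torch.optim.Optimizer)
    )


def build_optimizer(model: torch.nn.Module, cls_or_factory: Any, **kwargs: Any) -> Any:
    """Hypothetical dispatcher; not part of the original module."""
    if is_optimizer_factory(cls_or_factory):
        # Assumed convention: factories receive the whole model so they can group parameters.
        return cls_or_factory(model, **kwargs)
    # Plain Optimizer classes only need the parameter iterable.
    return cls_or_factory(model.parameters(), **kwargs)


def adam_factory(model: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
    # Hypothetical factory: any callable that is not an Optimizer subclass qualifies.
    return torch.optim.Adam(model.parameters(), lr=lr)


model = torch.nn.Linear(4, 2)
sgd = build_optimizer(model, torch.optim.SGD, lr=0.1)  # class path
adam = build_optimizer(model, adam_factory, lr=1e-3)   # factory path
```

Because the check recognizes only `Optimizer` *classes*, an already constructed optimizer instance (such as `sgd` above) also returns True, as does any class that does not subclass `torch.optim.Optimizer`. Callers should therefore pass either a class or a factory, never an instance.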
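The docstring says factories exist so that an optimizer can split parameters between multiple internal optimizers. Here is a minimal sketch of that shape, assuming a split by tensor rank (matrices versus vectors and scalars). `CombinedOptimizer` and `split_by_rank_factory` are hypothetical. Muon and Dion are not part of `torch.optim`, so SGD with momentum stands in for the matrix optimizer. This shows only how parameters are routed, not Muon's or Dion's update rule. Since `split_by_rank_factory` is a plain function, `is_optimizer_factory` would return True for it.

```python
import torch


class CombinedOptimizer:
    """Hypothetical wrapper that steps several internal optimizers as one."""

    def __init__(self, optimizers: list[torch.optim.Optimizer]) -> None:
        self.optimizers = optimizers

    def step(self) -> None:
        for opt in self.optimizers:
            opt.step()

    def zero_grad(self, set_to_none: bool = True) -> None:
        for opt in self.optimizers:
            opt.zero_grad(set_to_none=set_to_none)


def split_by_rank_factory(model: torch.nn.Module, lr: float = 1e-3) -> CombinedOptimizer:
    """Hypothetical factory: matrices go to one optimizer, everything else to another."""
    params = [p for p in model.parameters() if p.requires_grad]
    matrix_params = [p for p in params if p.ndim >= 2]  # e.g. Linear weights
    other_params = [p for p in params if p.ndim < 2]    # e.g. biases, LayerNorm gains
    optimizers: list[torch.optim.Optimizer] = []
    if matrix_params:
        # Stand-in for a matrix-specific optimizer such as Muon; NOT Muon's update rule.
        optimizers.append(torch.optim.SGD(matrix_params, lr=lr, momentum=0.9))
    if other_params:
        optimizers.append(torch.optim.AdamW(other_params, lr=lr))
    return CombinedOptimizer(optimizers)


torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.LayerNorm(8), torch.nn.Linear(8, 1))
opt = split_by_rank_factory(model, lr=0.01)

loss = model(torch.randn(16, 4)).pow(2).mean()
loss.backward()
before = model[0].weight.detach().clone()
opt.step()
opt.zero_grad()
```

The wrapper deliberately does not subclass `torch.optim.Optimizer`. It only forwards `step` and `zero_grad`. A real factory would likely also need state-dict and scheduler support, and would have to handle the sharding concerns the docstring mentions.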