| 3 | import math |
| 4 | |
| 5 | class AdamW(Optimizer): |
| 6 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, amsgrad=False): |
| 7 | if lr < 0.0: |
| 8 | raise ValueError(f"Invalid learning rate: {lr}") |
| 9 | if eps < 0.0: |
| 10 | raise ValueError(f"Invalid epsilon value: {eps}") |
| 11 | if not 0.0 <= betas[0] < 1.0: |
| 12 | raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") |
| 13 | if not 0.0 <= betas[1] < 1.0: |
| 14 | raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") |
| 15 | |
| 16 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) |
| 17 | super().__init__(params, defaults) |
| 18 | |
| 19 | def step(self, closure=None): |
| 20 | loss = None |
| 21 | if closure is not None: |
| 22 | loss = closure() |
| 23 | |
| 24 | for group in self.param_groups: |
| 25 | for p in group['params']: |
| 26 | if p.grad is None: |
| 27 | continue |
| 28 | |
| 29 | grad = p.grad.data |
| 30 | if grad.is_sparse: |
| 31 | raise RuntimeError('AdamW does not support sparse gradients') |
| 32 | |
| 33 | amsgrad = group['amsgrad'] |
| 34 | state = self.state[p] |
| 35 | |
| 36 | if len(state) == 0: |
| 37 | state['step'] = 0 |
| 38 | state['exp_avg'] = torch.zeros_like(p.data) |
| 39 | state['exp_avg_sq'] = torch.zeros_like(p.data) |
| 40 | if amsgrad: |
| 41 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) |
| 42 | |
| 43 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] |
| 44 | if amsgrad: |
| 45 | max_exp_avg_sq = state['max_exp_avg_sq'] |
| 46 | beta1, beta2 = group['betas'] |
| 47 | |
| 48 | state['step'] += 1 |
| 49 | |
| 50 | p.data.mul_(1 - group['lr'] * group['weight_decay']) |
| 51 | |
| 52 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) |
| 53 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) |
| 54 | |
| 55 | if amsgrad: |
| 56 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) |
| 57 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) |
| 58 | else: |
| 59 | denom = exp_avg_sq.sqrt().add_(group['eps']) |
| 60 | |
| 61 | bias_correction1 = 1 - beta1 ** state['step'] |
| 62 | bias_correction2 = 1 - beta2 ** state['step'] |