| 235 | group['counter'] = 0 |
| 236 | |
def update(self, group):
    """Blend each parameter of ``group`` with its slow copy, then sync them.

    For every entry of ``group['params']`` whose ``grad`` is not ``None``:

    1. On first sight, a ``'slow_buffer'`` tensor is created in
       ``self.state[param]`` and filled with the parameter's current values.
    2. The slow buffer is moved toward the parameter by the fraction
       ``self.alpha``: ``slow += alpha * (param - slow)``.
    3. The parameter is overwritten in place with the updated slow buffer.

    Parameters without a gradient are skipped entirely; no state entry is
    touched for them.

    Args:
        group: Mapping with a ``'params'`` entry that yields the tensors
            to update.

    Returns:
        None. ``self.state`` and the parameters are modified in place.
    """
    # Only parameters that received a gradient take part in the update.
    for weights in (p for p in group['params'] if p.grad is not None):
        buffers = self.state[weights]
        current = weights.data

        if 'slow_buffer' not in buffers:
            # Lazily seed the slow copy from the parameter's current values.
            seeded = torch.empty_like(current)
            buffers['slow_buffer'] = seeded
            seeded.copy_(current)

        anchor = buffers['slow_buffer']
        # Move the slow copy toward the current values by a factor of alpha ...
        delta = current - anchor
        anchor.add_(delta, alpha=self.alpha)
        # ... and write the blended values back into the parameter.
        current.copy_(anchor)
| 249 | |
| 250 | def step(self, closure=None): |
| 251 | loss = self.optimizer.step(closure) |