Rolling window average for smoothing metric values. Maintains a sliding window of values and computes their average, useful for smoothing noisy metric values before making learning rate decisions. Args: window_size (`int`): The maximum number of values to keep in the rolling window.
| 579 | |
| 580 | |
class StreamingAverage:
    """Rolling window average for smoothing metric values.

    Maintains a sliding window of the most recent values together with their
    running sum, so each update returns the window average without re-summing
    the whole window. Useful for smoothing noisy metric values before making
    learning rate decisions.

    Args:
        window_size (`int`):
            The maximum number of values to keep in the rolling window.
            Must be at least 1 by the time `streamavg` is called.
    """

    def __init__(self, window_size: int) -> None:
        self.window_size: int = window_size
        # Oldest value first; the newest value is appended at the end.
        self.values: list[float] = []
        # Running sum of `values`, updated incrementally by `streamavg`.
        self.sum: float = 0.0

    def streamavg(self, value: float) -> float:
        """Add a value and return the current rolling average.

        Args:
            value (`float`):
                The new value to push into the window.

        Returns:
            `float`: The average of the values currently held in the window,
            including `value`.

        Raises:
            ValueError: If `window_size` is smaller than 1. The check runs
                before any state is modified.
        """
        # Checked here rather than in `__init__` because `load_state_dict`
        # can also change `window_size`; an empty window would otherwise end
        # in a division by zero below.
        if self.window_size < 1:
            raise ValueError(f"window_size must be at least 1, got {self.window_size}")

        self.values.append(value)
        self.sum += value

        # `while` rather than `if`: a state restored through `load_state_dict`
        # may hold more values than the window allows, and every excess value
        # has to be evicted, not just one per call.
        while len(self.values) > self.window_size:
            self.sum -= self.values.pop(0)

        return self.sum / len(self.values)

    def state_dict(self) -> dict[str, Any]:
        """Return a snapshot of the window.

        Returns:
            `dict[str, Any]`: A dict with the keys `"window_size"`, `"values"`
            (a copy of the stored values, so mutating it does not affect this
            object) and `"sum"`.
        """
        return {
            "window_size": self.window_size,
            "values": list(self.values),
            "sum": self.sum,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the window from a dict shaped like the output of `state_dict`.

        Missing keys fall back as follows: `"window_size"` keeps the current
        value, `"values"` becomes an empty window, and `"sum"` is recomputed
        from the restored values so the running sum stays consistent with them.

        Args:
            state_dict (`dict[str, Any]`):
                The state to restore. The values are copied, so the caller's
                container is not shared with this object.
        """
        self.window_size = state_dict.get("window_size", self.window_size)
        # `list(...)` copies and also accepts any iterable (e.g. a tuple).
        self.values = list(state_dict.get("values", []))
        if "sum" in state_dict:
            self.sum = state_dict["sum"]
        else:
            self.sum = float(sum(self.values))
| 619 | |
| 620 | |
| 621 | class GreedyLR: |
no outgoing calls