Load state from a dictionary.
(self, state_dict: dict[str, Any])
| 890 | return state |
| 891 | |
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore this object's attributes from ``state_dict``.

    Every attribute listed below is overwritten with the value stored under
    the same key; when a key is absent the attribute keeps its current value.

    Two entries get extra handling:

    * ``"_streaming_avg"`` -- forwarded to the streaming average's own
      ``load_state_dict``. A ``StreamingAverage`` is created first (sized by
      the just-restored ``window_size``) when none exists yet.
    * ``"_last_lr"`` -- after being restored, the values are written into
      ``self.optimizer.param_groups`` pairwise; ``zip`` stops at the shorter
      of the two sequences.

    Args:
        state_dict: Mapping from attribute name to saved value.
    """
    # Order matters only in that "window_size" is restored before it is used
    # to build a StreamingAverage below; otherwise the entries are independent.
    restorable = (
        "factor",
        "min_lrs",
        "max_lrs",
        "patience",
        "verbose",
        "cooldown",
        "warmup",
        "cooldown_counter",
        "warmup_counter",
        "mode",
        "threshold",
        "threshold_mode",
        "best",
        "num_bad_epochs",
        "num_good_epochs",
        "eps",
        "last_epoch",
        "smooth",
        "window_size",
        "reset_start",
        "reset_start_original",
        "_last_lr",
        "_init_lrs",
    )
    for attr in restorable:
        # The current value is the fallback, so missing keys are a no-op.
        setattr(self, attr, state_dict.get(attr, getattr(self, attr)))

    if "_streaming_avg" in state_dict:
        if self._streaming_avg is None:
            self._streaming_avg = StreamingAverage(self.window_size)
        saved_average = state_dict["_streaming_avg"]
        self._streaming_avg.load_state_dict(saved_average)

    if "_last_lr" in state_dict:
        # Push the restored learning rates into the optimizer itself.
        pairs = zip(self.optimizer.param_groups, self._last_lr)
        for group, restored_lr in pairs:
            group["lr"] = restored_lr
| 926 | |
| 927 | |
| 928 | def get_greedy_schedule(optimizer: Optimizer, **kwargs): |
no test coverage detected