MCPcopy
hub / github.com/huggingface/transformers / load_state_dict

Method load_state_dict

src/transformers/optimization.py:892–925  ·  view source on GitHub ↗

Load state from a dictionary.

(self, state_dict: dict[str, Any])

Source from the content-addressed store, hash-verified

890 return state
891
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore this object's attributes from ``state_dict``.

    Every recognised key that is present in ``state_dict`` overwrites the
    attribute of the same name; a key that is absent leaves the current
    attribute value untouched.

    Two entries trigger extra work:

    * ``"_streaming_avg"`` -- its value is handed to the ``load_state_dict``
      method of ``self._streaming_avg``. If that attribute is currently
      ``None``, a ``StreamingAverage`` is first built from the (already
      restored) ``window_size``.
    * ``"_last_lr"`` -- the restored values are written, pairwise, into the
      ``"lr"`` entry of ``self.optimizer.param_groups``.

    Args:
        state_dict: Mapping of attribute names to the values to restore.
    """
    # The order below mirrors the original assignment order. It matters for
    # "window_size" (read when the streaming average is created) and for
    # "_last_lr" (read by the lr write-back at the end).
    restorable_attrs = (
        "factor",
        "min_lrs",
        "max_lrs",
        "patience",
        "verbose",
        "cooldown",
        "warmup",
        "cooldown_counter",
        "warmup_counter",
        "mode",
        "threshold",
        "threshold_mode",
        "best",
        "num_bad_epochs",
        "num_good_epochs",
        "eps",
        "last_epoch",
        "smooth",
        "window_size",
        "reset_start",
        "reset_start_original",
        "_last_lr",
        "_init_lrs",
    )
    for attr_name in restorable_attrs:
        current_value = getattr(self, attr_name)
        setattr(self, attr_name, state_dict.get(attr_name, current_value))

    if "_streaming_avg" in state_dict:
        averager = self._streaming_avg
        if averager is None:
            # Build the helper on demand so there is something to load into.
            averager = StreamingAverage(self.window_size)
            self._streaming_avg = averager
        averager.load_state_dict(state_dict["_streaming_avg"])

    if "_last_lr" not in state_dict:
        return

    # Push the restored learning rates back into the optimizer's groups.
    for group, restored_lr in zip(self.optimizer.param_groups, self._last_lr):
        group["lr"] = restored_lr
926
927
928def get_greedy_schedule(optimizer: Optimizer, **kwargs):

Callers 6

load_sharded_checkpoint — Function · 0.45
_load_from_checkpoint — Method · 0.45
_load_best_model — Method · 0.45
opt_load_hook — Method · 0.45
_load_scaler — Method · 0.45

Calls 2

StreamingAverage — Class · 0.85
get — Method · 0.45

Tested by

no test coverage detected