MCPcopy
hub / github.com/huggingface/transformers / _load_callback_state

Method _load_callback_state

src/transformers/trainer.py:3718–3755  ·  view source on GitHub ↗

If the checkpoint contains saved callback states and matching callbacks were passed in, restore those states (only when `restore_callback_states_from_checkpoint` is enabled).

(self)

Source from the content-addressed store, hash-verified

3716 reissue_pt_warnings(caught_warnings)
3717
def _load_callback_state(self) -> None:
    """Restore stateful callbacks (and the `TrainerControl`) from the checkpoint, if enabled.

    Does nothing unless `self.args.restore_callback_states_from_checkpoint` is truthy.

    `self.state.stateful_callbacks` maps a callback class name to a saved state dict, or to a
    list of such dicts when several callbacks of the same class were saved. For each entry, the
    currently configured callbacks (plus `self.control`) whose class name matches are paired, in
    order, with the saved states. Each matched callback is rebuilt as
    `type(callback)(**state.get("args", {}))` and then gets every item of
    `state.get("attributes", {})` set on it as an attribute. The rebuilt object replaces the
    original:

    - a `TrainerControl` match replaces `self.control`;
    - any other match is removed from `self.callback_handler`, and its rebuilt instance is added
      back once all saved entries have been processed.

    Saved entries whose class name matches no configured callback are ignored and reported in a
    single warning.
    """
    if not self.args.restore_callback_states_from_checkpoint:
        return
    not_found = []
    new_callbacks = []
    # Snapshot the registered callbacks plus the control object; removals from the handler below
    # do not affect this list, so iterating it stays safe.
    original_callbacks = self.callback_handler.callbacks + [self.control]
    for stored_callback, data in self.state.stateful_callbacks.items():
        # A single saved state is stored bare; several same-class states are stored as a list.
        if not isinstance(data, list):
            data = [data]
        duplicates = [
            callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
        ]
        if not duplicates:
            not_found.append(stored_callback)
            continue
        # The i-th configured instance receives the i-th saved state; surplus items on either side
        # are left untouched (zip stops at the shorter sequence).
        for callback, callback_data in zip(duplicates, data):
            args = callback_data.get("args", {})
            attributes = callback_data.get("attributes", {})
            new_callback = type(callback)(**args)
            for attribute, value in attributes.items():
                setattr(new_callback, attribute, value)
            if isinstance(callback, TrainerControl):
                # `control` is not one of the handler's callbacks; swap it in directly.
                self.control = new_callback
            else:
                # Remove this exact instance. Passing the class to `remove_callback` removes only the
                # first instance of that class, which would leave stale originals registered when
                # several callbacks share a class (or could remove an unrelated subclass instance).
                self.callback_handler.remove_callback(callback)
                new_callbacks.append(new_callback)
        logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
    if not_found:
        logger.warning(
            f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
        )
    for callback in new_callbacks:
        self.callback_handler.add_callback(callback)
3756
3757 def _issue_warnings_after_load(self, load_result: Any) -> None:
3758 """Log warnings for missing or unexpected keys after loading a checkpoint."""

Callers 2

_init_training_state — Method · 0.95

Calls 7

warning — Method · 0.80
join — Method · 0.80
items — Method · 0.45
get — Method · 0.45
remove_callback — Method · 0.45
info — Method · 0.45
add_callback — Method · 0.45

Tested by 1