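`_load_callback_state(self)`: if callback states exist in the checkpoint and were passed in, restore their states (only when `args.restore_callback_states_from_checkpoint` is enabled).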
| 3716 | reissue_pt_warnings(caught_warnings) |
| 3717 | |
def _load_callback_state(self) -> None:
    """If callback states exist and were passed in, restore their states if enabled.

    Does nothing unless ``self.args.restore_callback_states_from_checkpoint`` is set.

    ``self.state.stateful_callbacks`` maps a callback class name to one saved state
    (or a list of saved states), each a dict with optional ``"args"`` and
    ``"attributes"`` entries. For every stored name, the currently registered
    callbacks (plus ``self.control``) whose class name matches exactly are paired
    positionally with the saved states. This lets several callbacks of the same class
    each be restored, and any surplus on either side is left untouched.

    Each paired callback is rebuilt as ``type(callback)(**args)`` and then has every
    saved attribute set on it. A rebuilt ``TrainerControl`` replaces ``self.control``.
    Any other rebuilt callback replaces the exact instance it was built from: the old
    instance is removed from the handler and the new one is registered with
    ``add_callback`` after all entries are processed.

    Stored names with no matching callback are reported in one warning and ignored.
    """
    if not self.args.restore_callback_states_from_checkpoint:
        return
    # Callback states are stored in stateful_callbacks
    not_found = []
    new_callbacks = []
    # A new list, so removing callbacks from the handler below does not disturb iteration.
    original_callbacks = self.callback_handler.callbacks + [self.control]
    for stored_callback, data in self.state.stateful_callbacks.items():
        if not isinstance(data, list):
            data = [data]
        # Match on the exact class name. This single scan doubles as the "found" check.
        duplicates = [
            callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
        ]
        if not duplicates:
            not_found.append(stored_callback)
            continue
        # We can load/restore from multiple callbacks of the same type.
        for callback, callback_data in zip(duplicates, data):
            args = callback_data.get("args", {})
            attributes = callback_data.get("attributes", {})
            new_callback = type(callback)(**args)
            for attribute, value in attributes.items():
                setattr(new_callback, attribute, value)
            if isinstance(callback, TrainerControl):
                # Specifically for restoring the `control` state. The control object is
                # not one of the handler's callbacks, so there is nothing to remove.
                self.control = new_callback
            else:
                new_callbacks.append(new_callback)
                # Remove the exact instance being replaced. Removing by type could match a
                # different callback, such as an earlier instance of a subclass, because
                # matching above is by exact class name.
                self.callback_handler.remove_callback(callback)
        logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
    if not_found:
        logger.warning(
            f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
        )
    # Register the rebuilt callbacks only after all stored entries have been processed.
    for callback in new_callbacks:
        self.callback_handler.add_callback(callback)
| 3756 | |
| 3757 | def _issue_warnings_after_load(self, load_result: Any) -> None: |
| 3758 | """Log warnings for missing or unexpected keys after loading a checkpoint.""" |