Deactivates NEFTune on the model. Args: model (`torch.nn.Module`): The model to deactivate NEFTune on. hook_handle (`torch.utils.hooks.RemovableHandle`): The hook handle returned by `activate_neftune`. accelerator (`Accelerator`, *optional*): The accelerator instance. If provided, the model will be unwrapped before accessing embeddings.
(model, hook_handle, accelerator=None)
| 87 | |
| 88 | |
def deactivate_neftune(model, hook_handle, accelerator=None):
    """
    Turn NEFTune off for `model`.

    Removes the hook behind `hook_handle` and, if the model's input embeddings
    carry a `neftune_noise_alpha` attribute, deletes that attribute.

    Args:
        model (`torch.nn.Module`):
            The model to deactivate NEFTune on.
        hook_handle (`torch.utils.hooks.RemovableHandle`):
            The hook handle returned by `activate_neftune`.
        accelerator (`Accelerator`, *optional*):
            The accelerator instance. If provided, `accelerator.unwrap_model` is
            applied to `model` before its embeddings are looked up.
    """
    # Without an accelerator the model is used as passed in.
    target = model if accelerator is None else accelerator.unwrap_model(model)

    # For PEFT models the embeddings are fetched from `base_model.model`
    # rather than from the wrapper itself.
    embedding_owner = target.base_model.model if _is_peft_model(target) else target
    input_embeddings = embedding_owner.get_input_embeddings()

    # The embeddings lookup above runs first, so if it raises the hook is
    # left registered.
    hook_handle.remove()
    if hasattr(input_embeddings, "neftune_noise_alpha"):
        del input_embeddings.neftune_noise_alpha
no test coverage detected