`apply_liger_kernel(model, kernel_config)`

Apply Liger Kernel optimizations to a model instance. Liger Kernel provides optimized Triton kernels for common transformer operations, and this function patches the model in place with those kernels. `model` is expected to be a `PreTrainedModel` or a PEFT wrapper around one; `kernel_config` holds keyword arguments forwarded to Liger's patching function.
```python
def apply_liger_kernel(model, kernel_config):
    """
    Apply Liger Kernel optimizations to a model instance.

    Liger Kernel provides optimized Triton kernels for common transformer operations.
    This function patches the model in place with those kernels.

    Args:
        model: The model to patch. Expected to be a `PreTrainedModel` or a PEFT wrapper
            around one; any other model is left unchanged and a warning is logged.
        kernel_config: Keyword arguments forwarded to Liger's
            `_apply_liger_kernel_to_instance`. `None` or an empty dict applies
            Liger's default kernel selection.

    Raises:
        ImportError: If liger-kernel >= 0.3.0 is not installed.
    """
    # `is_liger_kernel_available`, `unwrap_peft_model`, `PreTrainedModel`, and `logger`
    # are module-level names defined or imported outside this excerpt.

    # Fail fast if liger-kernel is missing or older than the supported version.
    if not is_liger_kernel_available():
        raise ImportError(
            "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
            "Please install it with `pip install liger-kernel`"
        )

    # Import lazily so this module can be loaded without liger-kernel installed.
    from liger_kernel.transformers import _apply_liger_kernel_to_instance

    # A missing config means "use Liger's defaults": forward no keyword arguments.
    kernel_config = kernel_config or {}
    # Strip a PEFT wrapper (if any) so the underlying transformers model is patched.
    base_model = unwrap_peft_model(model)

    if isinstance(base_model, PreTrainedModel):
        _apply_liger_kernel_to_instance(model=base_model, **kernel_config)
    else:
        logger.warning("The model is not an instance of PreTrainedModel. No liger kernels will be applied.")
```
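To make the control flow testable without liger-kernel, Triton, or transformers installed, the sketch below mirrors the function's dispatch with every dependency replaced by a stand-in. `FakePreTrainedModel`, `FakePeftWrapper`, `apply_kernels_sketch`, and this version of `unwrap_peft_model` are all hypothetical. The excerpt does not show how the real `unwrap_peft_model` works; here it is assumed to call `get_base_model()` when the object provides one, as PEFT's `PeftModel` does. The patching call is injected as `apply_fn`, standing in for `_apply_liger_kernel_to_instance`.

```python
import logging
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)


class FakePreTrainedModel:
    """Hypothetical stand-in for transformers.PreTrainedModel."""


class FakePeftWrapper:
    """Hypothetical stand-in for a PEFT wrapper that exposes get_base_model()."""

    def __init__(self, base: Any) -> None:
        self._base = base

    def get_base_model(self) -> Any:
        return self._base


def unwrap_peft_model(model: Any) -> Any:
    """Hypothetical: return the wrapped base model if `model` looks like a PEFT wrapper."""
    get_base = getattr(model, "get_base_model", None)
    return get_base() if callable(get_base) else model


def apply_kernels_sketch(
    model: Any,
    kernel_config: Optional[dict],
    apply_fn: Callable[..., None],
) -> bool:
    """Mirror of apply_liger_kernel's dispatch, with the patching call injected.

    Returns True if apply_fn was called, False if the model was skipped.
    """
    kernel_config = kernel_config or {}  # None -> forward no keyword arguments
    base_model = unwrap_peft_model(model)
    if isinstance(base_model, FakePreTrainedModel):
        apply_fn(model=base_model, **kernel_config)
        return True
    logger.warning("The model is not an instance of PreTrainedModel. No liger kernels will be applied.")
    return False


# Usage: record what would be forwarded to the real patching function.
calls = []


def record(model: Any, **kwargs: Any) -> None:
    calls.append((model, kwargs))


base = FakePreTrainedModel()
apply_kernels_sketch(FakePeftWrapper(base), None, record)  # wrapper is unwrapped, config defaults to {}
apply_kernels_sketch(base, {"example_flag": False}, record)  # hypothetical key, forwarded as-is
skipped = not apply_kernels_sketch(object(), None, record)  # not a model: warning, no call
```

Injecting `apply_fn` keeps the sketch free of GPU requirements, and it makes visible that a `None` config reaches the patching function as no keyword arguments at all, so whatever defaults that function declares take effect.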