A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules. Parameters: model (`torch.nn.Module`): Input model; all of its submodules are visited via `named_modules()` and eligible linear layers are replaced in place. modules_to_not_convert (`list[str]`, *optional*, defaults to `None`)
(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False)
| 84 | |
| 85 | |
def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False):
    """
    A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules.

    The replacement happens in place on `model`. Each new `EetqLinear` is constructed under
    `torch.device("meta")`. As a side effect, the module-level global `eetq_kernels_hub` is (re)set to the
    kernel returned by `get_kernel("kernels-community/quantization-eetq")`.

    Parameters:
        model (`torch.nn.Module`):
            Input model. Every submodule is visited through `model.named_modules()`, and each eligible
            `torch.nn.Linear` is swapped with `model.set_submodule`.
        modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons. Filtering is delegated to `should_convert_module`.
        pre_quantized (`bool`, *optional*, defaults to `False`):
            When `False`, `dtype=None` is additionally passed to `EetqLinear`, which is needed to correctly
            materialize the weights during quantization. When `True`, no extra keyword arguments are passed.

    Returns:
        `torch.nn.Module`: the same `model` object, with its eligible linear layers replaced.
    """
    from .hub_kernels import get_kernel

    # Published at module scope so other code in this module can reach the loaded kernels
    # (presumably `EetqLinear` — TODO confirm against the rest of the file).
    global eetq_kernels_hub
    eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")

    has_been_replaced = False
    # we need this to correctly materialize the weights during quantization
    module_kwargs = {} if pre_quantized else {"dtype": None}

    # Snapshot the traversal before mutating the tree: `set_submodule` rewrites parents' child
    # dicts, and doing that while a live `named_modules()` generator walks them is fragile.
    for module_name, module in list(model.named_modules()):
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        if not isinstance(module, nn.Linear):
            continue
        # Construct the replacement on the meta device so no real storage is allocated here.
        with torch.device("meta"):
            new_module = EetqLinear(
                module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
            )
        model.set_submodule(module_name, new_module)
        has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using eetq but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model