hub / github.com/huggingface/transformers / replace_with_eetq_linear

Function replace_with_eetq_linear

src/transformers/integrations/eetq.py:86–123  ·  view source on GitHub ↗

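A helper function that replaces all `torch.nn.Linear` modules with `EetqLinear` modules.

Parameters:
model (`torch.nn.Module`): Input model. Every submodule returned by `named_modules()` is considered for replacement.
modules_to_not_convert (`list[str]`, *optional*, defaults to `None`): Names of the modules not to convert to `EetqLinear`. In practice, `lm_head` is kept in full precision for numerical stability.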

(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False)

Source from the content-addressed store, hash-verified

def replace_with_eetq_linear(model, modules_to_not_convert: list[str] | None = None, pre_quantized=False):
    """
    A helper function to replace all `torch.nn.Linear` modules by `EetqLinear` modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`list[`str`]`, *optional*, defaults to `None`):
            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
    """
    from .hub_kernels import get_kernel

    global eetq_kernels_hub
    eetq_kernels_hub = get_kernel("kernels-community/quantization-eetq")

    has_been_replaced = False
    # we need this to correctly materialize the weights during quantization
    module_kwargs = {} if pre_quantized else {"dtype": None}
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue
        with torch.device("meta"):
            if isinstance(module, nn.Linear):
                new_module = EetqLinear(
                    module.in_features, module.out_features, bias=module.bias is not None, **module_kwargs
                )
                model.set_submodule(module_name, new_module)
                has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using eetq but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model

Calls 5

get_kernel · Function · 0.85
should_convert_module · Function · 0.85
EetqLinear · Class · 0.85
warning · Method · 0.80
device · Method · 0.45
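Of the callees listed, `should_convert_module` is the one that decides whether a module name is excluded by `modules_to_not_convert`. Its implementation is not shown on this page, so the following is only a sketch of one plausible matching rule, under the assumption that a pattern excludes any module whose dotted name contains it as a run of whole components. `should_convert_sketch` is a hypothetical name, and the real function's rules may differ.

```python
def should_convert_sketch(module_name, modules_to_not_convert=None):
    """Hypothetical stand-in for should_convert_module; not the transformers implementation."""
    # With no exclusion list (None or empty), every module is eligible for conversion.
    if not modules_to_not_convert:
        return True

    parts = module_name.split(".")
    for pattern in modules_to_not_convert:
        pattern_parts = pattern.split(".")
        width = len(pattern_parts)
        # Exclude the module when the pattern matches a contiguous run of whole name components.
        for start in range(len(parts) - width + 1):
            if parts[start:start + width] == pattern_parts:
                return False
    return True


names = ["model.layers.0.mlp.up_proj", "model.lm_head", "lm_head", "my_lm_head_extra"]
eligible = [name for name in names if should_convert_sketch(name, ["lm_head"])]
```

Matching on whole components rather than raw substrings keeps a name such as `my_lm_head_extra` eligible while still excluding both `lm_head` and `model.lm_head`.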

Tested by 1