
Function _discover_fusable_modules

src/transformers/fusion_mapping.py:142–187  ·  github.com/huggingface/transformers

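Discover compatible module classes for one fusion family on a meta-initialized model. The function instantiates `cls(config)` on the meta device, scans `named_modules()` for candidate modules, optionally pre-filters them with `target_modules_patterns`, uses `is_fusable(...)` as the final structural check, and returns the class-level patch mapping (original class name to fused replacement class) used by monkey patching.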
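A minimal sketch of the optional name pre-filter, using only the standard `re` module. The pattern list and module names are hypothetical; the sketch mirrors how the source joins `target_modules_patterns` into a single alternation and tests each dotted module name with `search`. Because `search` matches anywhere in the string, a pattern needs its own `$` (or `^`) anchor if it should only hit names ending (or starting) a certain way.

```python
import re

# Hypothetical patterns; real values come from a ModuleFusionSpec's
# target_modules_patterns.
target_modules_patterns = [r"self_attn\.q_proj$", r"mlp\.gate_proj$"]

# Join every pattern into one alternation; no patterns means no filtering.
target_module_pattern = (
    re.compile("|".join(target_modules_patterns)) if target_modules_patterns else None
)

# Hypothetical dotted names in the style of nn.Module.named_modules(),
# where the root module is reported with the empty name "".
module_names = [
    "",
    "model.layers.0.self_attn",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.down_proj",
]


def passes_prefilter(name: str) -> bool:
    # search() scans the whole name, unlike match(), which anchors at the start.
    return target_module_pattern is None or target_module_pattern.search(name) is not None


kept = [name for name in module_names if passes_prefilter(name)]
print(kept)
```

Only the two names ending in `self_attn.q_proj` and `mlp.gate_proj` survive; the surviving candidates still have to pass `is_fusable(...)`.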

(
    cls: "type[PreTrainedModel]",
    config: "PretrainedConfig",
    fusion_name: str,
    spec: ModuleFusionSpec,
) -> dict[str, type[nn.Module]]


```python
def _discover_fusable_modules(
    cls: "type[PreTrainedModel]",
    config: "PretrainedConfig",
    fusion_name: str,
    spec: ModuleFusionSpec,
) -> dict[str, type[nn.Module]]:
    """Discover compatible module classes for one fusion family on a meta-initialized model.

    This function:
    - instantiates `cls(config)` on the meta device
    - scans `named_modules()` for candidate modules
    - optionally pre-filters them with `target_modules_patterns`
    - uses `is_fusable(...)` as the final structural check
    - builds the class-level patch mapping used by monkey patching

    Results are cached per `(fusion_name, cls)` to avoid repeated meta-initialization.
    This matches the current class-level fusion behavior, where one compatible
    module class maps to one fused replacement class.
    """

    cache = _FUSION_DISCOVERY_CACHE.setdefault(fusion_name, {})
    if cls in cache:
        return cache[cls]

    with torch.device("meta"):
        model = cls(config)

    seen_classes = set()
    patch_mapping = {}
    target_module_pattern = (
        re.compile("|".join(spec.target_modules_patterns)) if spec.target_modules_patterns else None
    )
    for module_name, module in model.named_modules():
        module_cls = type(module)
        if module_cls in seen_classes:
            continue
        if target_module_pattern is not None and target_module_pattern.search(module_name) is None:
            continue
        if not spec.is_fusable(module):
            continue

        seen_classes.add(module_cls)
        patch_mapping[module_cls.__name__] = spec.make_fused_class(module_cls)

    cache[cls] = patch_mapping
    return patch_mapping
```
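To see how the discovery loop and the per-`(fusion_name, cls)` cache interact without PyTorch, here is a minimal pure-Python sketch that mirrors the function's logic. Everything in it is hypothetical: `Module` is a tiny stand-in for `nn.Module` whose `named_modules()` yields dotted names root-first (without PyTorch's duplicate-module handling), `ToySpec` stands in for `ModuleFusionSpec`, and `discover` plays the role of `_discover_fusable_modules` but builds the model normally rather than under `torch.device("meta")`.

```python
import re


class Module:
    """Hypothetical stand-in for nn.Module holding named child modules."""

    def __init__(self, **children):
        self._children = children

    def named_modules(self, prefix=""):
        # Depth-first, root first, yielding (dotted_name, module).
        yield prefix, self
        for name, child in self._children.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


class Attention(Module): pass
class MLP(Module): pass


class Block(Module):
    def __init__(self):
        super().__init__(attn=Attention(), mlp=MLP())


class ToyModel(Module):
    def __init__(self, config):
        super().__init__(layers_0=Block(), layers_1=Block())


class ToySpec:
    """Hypothetical fusion spec: only Attention counts as fusable."""

    target_modules_patterns = None

    def is_fusable(self, module) -> bool:
        return isinstance(module, Attention)

    def make_fused_class(self, module_cls):
        # Subclass named Fused<Original>; a real spec would change behavior here.
        return type(f"Fused{module_cls.__name__}", (module_cls,), {})


_CACHE: dict[str, dict[type, dict[str, type]]] = {}
instantiations = 0  # counts how often a model is actually built


def discover(cls, config, fusion_name, spec):
    global instantiations
    cache = _CACHE.setdefault(fusion_name, {})
    if cls in cache:
        return cache[cls]  # cache hit: no model is built
    instantiations += 1
    model = cls(config)
    patterns = spec.target_modules_patterns
    pattern = re.compile("|".join(patterns)) if patterns else None
    seen, mapping = set(), {}
    for name, module in model.named_modules():
        module_cls = type(module)
        if module_cls in seen:
            continue
        if pattern is not None and pattern.search(name) is None:
            continue
        if not spec.is_fusable(module):
            continue
        seen.add(module_cls)
        mapping[module_cls.__name__] = spec.make_fused_class(module_cls)
    cache[cls] = mapping
    return mapping


spec = ToySpec()
first = discover(ToyModel, config=None, fusion_name="toy_attention", spec=spec)
second = discover(ToyModel, config=None, fusion_name="toy_attention", spec=spec)
print(list(first), first["Attention"].__name__, instantiations)
```

The second call returns the cached dictionary without building the model again. The sketch also shows two consequences of the class-level design in the real loop: the second `Attention` instance is skipped because its class is already in the seen set, so `is_fusable` only inspects the first qualifying instance of each class; and a class is added to the seen set only after it passes both checks, so an instance rejected by the name pattern does not prevent a later instance of the same class from being picked up.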

Callers (1)

_register_module_fusion · Function · 0.85

Calls (6)

setdefault · Method · 0.80
join · Method · 0.80
device · Method · 0.45
is_fusable · Method · 0.45
add · Method · 0.45
make_fused_class · Method · 0.45

Tested by

no test coverage detected