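Discover compatible module classes for one fusion family on a meta-initialized model. This function: - instantiates `cls(config)` on the meta device - scans `named_modules()` for candidate modules - optionally pre-filters them with `target_modules_patterns` - uses `is_fusable(...)` as the final structural check - builds the class-level patch mapping used by monkey patching. Results are cached per `(fusion_name, cls)` to avoid repeated meta-initialization.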
(
cls: "type[PreTrainedModel]",
config: "PretrainedConfig",
fusion_name: str,
spec: ModuleFusionSpec,
)
| 140 | |
| 141 | |
def _discover_fusable_modules(
    cls: "type[PreTrainedModel]",
    config: "PretrainedConfig",
    fusion_name: str,
    spec: ModuleFusionSpec,
) -> dict[str, type[nn.Module]]:
    """Build the class-name -> fused-class mapping for one fusion family.

    The model class is instantiated as `cls(config)` inside a
    `torch.device("meta")` context and its `named_modules()` are walked once.
    A submodule contributes an entry when:

    - its class has not already been recorded during this walk,
    - its qualified name matches at least one of
      `spec.target_modules_patterns` (skipped when no patterns are given), and
    - `spec.is_fusable(module)` accepts it.

    For each accepted class, `spec.make_fused_class(module_cls)` is called once
    and stored under `module_cls.__name__`; this is the class-level patch
    mapping used by monkey patching.

    Args:
        cls: Model class to instantiate on the meta device.
        config: Configuration passed to `cls`.
        fusion_name: Name of the fusion family; first-level cache key.
        spec: Fusion spec providing `target_modules_patterns`, `is_fusable`
            and `make_fused_class`.

    Returns:
        Mapping from original module class name to its fused replacement
        class. Results are memoized per `(fusion_name, cls)` in
        `_FUSION_DISCOVERY_CACHE`, and the cached dict object itself is
        returned on later calls.
    """
    family_cache = _FUSION_DISCOVERY_CACHE.setdefault(fusion_name, {})
    if cls in family_cache:
        return family_cache[cls]

    # Meta-device construction builds the module tree for inspection.
    with torch.device("meta"):
        model = cls(config)

    # All patterns are merged into one alternation so each name is searched once.
    name_filter = None
    if spec.target_modules_patterns:
        name_filter = re.compile("|".join(spec.target_modules_patterns))

    fused_by_class_name = {}
    recorded_classes = set()
    for qualified_name, submodule in model.named_modules():
        submodule_cls = type(submodule)
        if submodule_cls in recorded_classes:
            continue

        # The name filter runs before `is_fusable`, which is the final check.
        name_matches = name_filter is None or name_filter.search(qualified_name) is not None
        if name_matches and spec.is_fusable(submodule):
            recorded_classes.add(submodule_cls)
            fused_by_class_name[submodule_cls.__name__] = spec.make_fused_class(submodule_cls)

    family_cache[cls] = fused_by_class_name
    return fused_by_class_name
| 188 | |
| 189 | |
| 190 | def _register_module_fusion( |
no test coverage detected