| 222 | |
| 223 | |
class HYV3PreTrainedModel(LlamaPreTrainedModel):
    """Shared pretrained-model base for HYV3: loading hints, output recording and weight init."""

    # Multi-token prediction (MTP) is not supported at the moment, so unexpected
    # checkpoint keys matching layer 80 are ignored on load
    # (presumably those are the MTP weights -- confirm against the checkpoint layout).
    _keys_to_ignore_on_load_unexpected = [r"model\.layers\.80.*"]
    # NOTE(review): by its name, this list pins the router score-correction bias to fp32.
    _keep_in_fp32_modules_strict = ["e_score_correction_bias"]
    # Output name -> module class (or recorder spec) whose outputs are captured under that name.
    _can_record_outputs = {
        "router_logits": OutputRecorder(HYV3TopKRouter, index=0),
        "hidden_states": HYV3DecoderLayer,
        "attentions": HYV3Attention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize `module`, adding HYV3-specific handling for router, experts and MoE blocks.

        The generic initialization is invoked directly on `PreTrainedModel` (not via
        `super()`), after which:

        - `HYV3TopKRouter`: `weight` is drawn from N(0, initializer_range).
        - `HYV3Experts`: `gate_up_proj` then `down_proj` are drawn from N(0, initializer_range).
        - `HYV3MoE`: `e_score_correction_bias` is zeroed.

        Only the first matching branch is applied.
        """
        PreTrainedModel._init_weights(self, module)
        # Read eagerly, exactly once, regardless of which branch (if any) matches.
        init_std = self.config.initializer_range

        if isinstance(module, HYV3TopKRouter):
            init.normal_(module.weight, mean=0.0, std=init_std)
            return

        if isinstance(module, HYV3Experts):
            # Attribute lookup happens per iteration so the order of reads and
            # random draws is the same as initializing the two tensors in sequence.
            for param_name in ("gate_up_proj", "down_proj"):
                init.normal_(getattr(module, param_name), mean=0.0, std=init_std)
            return

        if isinstance(module, HYV3MoE):
            init.zeros_(module.e_score_correction_bias)
| 246 | |
| 247 | class HYV3Model(MiniMaxM2Model): |
nothing calls this directly
no test coverage detected