| 501 | |
@auto_docstring
class DogePreTrainedModel(PreTrainedModel):
    # Base class shared by the Doge models: declares the config type and the
    # capability flags consumed by the PreTrainedModel machinery, and adds
    # Doge-specific weight initialization on top of the base initializer.
    config: DogeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Decoder layers are kept whole when the model is split across devices.
    _no_split_modules = ["DogeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backends: flash-attn is off; SDPA, flex attention and the
    # generic attention-backend interface are declared as supported.
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False

    # Output-recording targets. "router_logits" is taken from index 1 of the
    # DogeCDMoE output (presumably the second element of its forward return —
    # confirm against DogeCDMoE.forward).
    _can_record_outputs = {
        "router_logits": OutputRecorder(DogeCDMoE, index=1),
        "hidden_states": DogeDecoderLayer,
        "attentions": DogeAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Apply the base initializer, then reset Doge-specific parameters.

        For a ``DogeAttention`` module, ``A`` is zero-filled. For a
        ``DogeDecoderLayer`` module, ``input_residual`` and then
        ``post_attention_residual`` are filled with ones. Each attribute is
        touched only if the module actually has it. Every other module type
        receives only the base-class initialization.

        Args:
            module: The submodule currently being initialized.
        """
        super()._init_weights(module)

        # Pick which attributes to reset and how; the DogeAttention check
        # takes precedence, exactly as an if/elif would.
        if isinstance(module, DogeAttention):
            attr_names, fill_fn = ("A",), init.zeros_
        elif isinstance(module, DogeDecoderLayer):
            attr_names, fill_fn = (
                "input_residual",
                "post_attention_residual",
            ), init.ones_
        else:
            return

        for attr_name in attr_names:
            # Guarded lookup: skip attributes this module instance does not define.
            if hasattr(module, attr_name):
                fill_fn(getattr(module, attr_name))
| 532 | |
| 533 | |
| 534 | @auto_docstring |
nothing calls this directly
no test coverage detected