| 767 | |
| 768 | |
| 769 | class BltPatcher(BltPreTrainedModel): |
| 770 | config: BltPatcherConfig |
| 771 | |
| 772 | def __init__(self, config: BltPatcherConfig): |
| 773 | super().__init__(config) |
| 774 | self.rotary_emb = BltRotaryEmbedding(config=self.config) |
| 775 | self.layers = nn.ModuleList() |
| 776 | for layer_idx in range(self.config.num_hidden_layers): |
| 777 | self.layers.append(BltTransformerLayer(self.config, layer_idx)) |
| 778 | self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size) |
| 779 | self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) |
| 780 | self.lm_head = nn.Linear( |
| 781 | self.config.hidden_size, |
| 782 | self.config.vocab_size, |
| 783 | bias=False, |
| 784 | ) |
| 785 | |
| 786 | self.post_init() |
| 787 | |
| 788 | def forward( |
| 789 | self, |
| 790 | input_ids: torch.LongTensor | None = None, |
| 791 | attention_mask: torch.Tensor | None = None, |
| 792 | position_ids: torch.LongTensor | None = None, |
| 793 | past_key_values: Cache | None = None, |
| 794 | inputs_embeds: torch.FloatTensor | None = None, |
| 795 | use_cache: bool | None = None, |
| 796 | patch_size: int | None = None, |
| 797 | threshold: float | None = None, |
| 798 | max_patch_length: int | None = None, |
| 799 | **kwargs: Unpack[TransformersKwargs], |
| 800 | ): |
| 801 | if (input_ids is None) ^ (inputs_embeds is not None): |
| 802 | raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
| 803 | |
| 804 | if inputs_embeds is None: |
| 805 | inputs_embeds = self.embed_tokens(input_ids) |
| 806 | |
| 807 | if use_cache and past_key_values is None: |
| 808 | past_key_values = DynamicCache(config=self.config) |
| 809 | |
| 810 | if position_ids is None: |
| 811 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| 812 | position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens |
| 813 | position_ids = position_ids.unsqueeze(0) |
| 814 | |
| 815 | causal_mask = create_causal_mask( |
| 816 | config=self.config, |
| 817 | inputs_embeds=inputs_embeds, |
| 818 | attention_mask=attention_mask, |
| 819 | past_key_values=past_key_values, |
| 820 | position_ids=position_ids, |
| 821 | ) |
| 822 | |
| 823 | hidden_states = inputs_embeds |
| 824 | position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| 825 | |
| 826 | for layer in self.layers: |