| 640 | |
| 641 | |
| 642 | class BltLocalDecoder(BltPreTrainedModel): |
| 643 | config: BltLocalDecoderConfig |
| 644 | |
| 645 | def __init__(self, config: BltLocalDecoderConfig):  # Build the decoder's transformer layers, rotary embedding, patch projection, final norm and cross-attention layers from `config`; keep statement order stable since it fixes submodule registration order. |
| 646 | super().__init__(config) |
| 647 | self.gradient_checkpointing = False  # off by default |
| 648 | self.config = config  # NOTE(review): may duplicate what super().__init__(config) already stores — confirm before removing |
| 649 | self.cross_attn_decoder = True  # hard-coded True, so the `not self.cross_attn_decoder` branch in forward() that adds patch_embeds to hidden_states never runs |
| 650 | self.layers = nn.ModuleList(  # one BltTransformerLayer per config.num_hidden_layers, each given its layer_idx |
| 651 | [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
| 652 | ) |
| 653 | self.rotary_emb = BltRotaryEmbedding(config=config)  # forward() calls this with (hidden_states, position_ids) to get position_embeddings |
| 654 | self.patch_embedding_projection = nn.Linear(  # maps patch_embeds of width config.hidden_size_global to cross_attn_k * hidden_size; forward() reshapes to (batch, patch_embeds.shape[1] * cross_attn_k, hidden_size) |
| 655 | in_features=config.hidden_size_global, |
| 656 | out_features=config.hidden_size * config.cross_attn_k, |
| 657 | bias=False, |
| 658 | ) |
| 659 | self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  # RMSNorm over config.hidden_size with eps=config.rms_norm_eps |
| 660 | self.cross_attn_layers = nn.ModuleList()  # one cross-attention layer per decoder layer when cross_attn_all_layers is set, else a single one; forward() uses cross-attention at i == 0 or when cross_attn_all_layers is set |
| 661 | layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1 |
| 662 | for layer_idx in range(layers_to_add): |
| 663 | self.cross_attn_layers.append( |
| 664 | BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size) |
| 665 | ) |
| 666 | |
| 667 | self.post_init()  # Transformers post-construction hook (presumably weight init / final setup — confirm) |
| 668 | |
| 669 | def forward( |
| 670 | self, |
| 671 | input_ids: torch.LongTensor | None = None, |
| 672 | inputs_embeds: torch.Tensor | None = None, |
| 673 | patch_embeds: torch.Tensor | None = None, |
| 674 | attention_mask: torch.Tensor | None = None, |
| 675 | position_ids: torch.LongTensor | None = None, |
| 676 | past_key_values: Cache | None = None, |
| 677 | encoder_attention_mask: torch.Tensor | None = None, |
| 678 | **kwargs: Unpack[TransformersKwargs], |
| 679 | ): |
| 680 | batch_size = inputs_embeds.shape[0] |
| 681 | hidden_states = inputs_embeds |
| 682 | patch_embeds = self.patch_embedding_projection(patch_embeds) |
| 683 | patch_embeds = patch_embeds.reshape( |
| 684 | batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size |
| 685 | ) |
| 686 | |
| 687 | if patch_embeds is not None and not self.cross_attn_decoder: |
| 688 | hidden_states = hidden_states + patch_embeds |
| 689 | |
| 690 | if position_ids is None: |
| 691 | position_ids = ( |
| 692 | torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) |
| 693 | ) |
| 694 | |
| 695 | position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| 696 | hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) |
| 697 | |
| 698 | for i, layer in enumerate(self.layers): |
| 699 | if i == 0 or self.config.cross_attn_all_layers: |