| 716 | |
| 717 | |
class BltGlobalTransformer(BltPreTrainedModel):
    """Stack of `BltTransformerLayer`s with rotary position embeddings, run over pre-computed embeddings.

    Inputs first go through `token_embedding_projection`. That is a bias-free `nn.Linear` from
    `config.encoder_cross_output_size` to `config.hidden_size` when the config defines a non-None
    `encoder_cross_output_size`; otherwise it is `nn.Identity()`.
    """

    config: BltGlobalTransformerConfig
    # Capture output index 1 of each `BltSelfAttention` matched by layer_name="global_transformer"
    # and expose it under the key "global_attentions".
    _can_record_outputs = {
        "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"),
    }

    def __init__(self, config: BltGlobalTransformerConfig):
        """Build `config.num_hidden_layers` transformer layers, the rotary embedding and the input projection.

        Args:
            config: Global-transformer configuration. This block reads `num_hidden_layers`,
                `hidden_size`, `dropout` and, optionally, `encoder_cross_output_size` from it.
        """
        super().__init__(config)
        self.config = config
        # Same submodules, in the same order, as appending one layer at a time.
        self.layers = nn.ModuleList(
            [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb = BltRotaryEmbedding(config=config)

        # Only project when the config declares a separate input width. Falling back to
        # nn.Identity() lets forward() call the projection unconditionally.
        if getattr(config, "encoder_cross_output_size", None) is not None:
            self.token_embedding_projection = nn.Linear(
                config.encoder_cross_output_size, config.hidden_size, bias=False
            )
        else:
            self.token_embedding_projection = nn.Identity()

        self.post_init()

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Project, apply dropout, then run every transformer layer in sequence.

        Args:
            inputs_embeds: Rank-3 tensor `(batch_size, seq_len, dim)`. `dim` must match
                `config.encoder_cross_output_size` when the projection is an `nn.Linear`.
                Presumably it equals `hidden_size` otherwise; confirm against the caller.
            attention_mask: Passed unchanged to every layer.
            position_ids: Positions used for the rotary embedding. If `None`, defaults to
                `0 .. seq_len - 1`, broadcast over the batch.
                NOTE(review): this default does not offset by tokens already held in
                `past_key_values`; callers using a cache should pass explicit positions. Confirm.
            past_key_values: Passed unchanged to every layer.
            **kwargs: Extra keyword arguments, forwarded to every layer.

        Returns:
            The output of the last layer. If there are no layers, returns the projected embeddings
            after dropout.
        """
        batch_size, seq_len, _ = inputs_embeds.shape
        hidden_states = self.token_embedding_projection(inputs_embeds)
        # Dropout is only active in training mode.
        hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
        # Rotary embeddings are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )
        return hidden_states
| 767 | |
| 768 | |
| 769 | class BltPatcher(BltPreTrainedModel): |