MCPcopy
hub / github.com/huggingface/transformers / BltGlobalTransformer

Class BltGlobalTransformer

src/transformers/models/blt/modular_blt.py:718–766  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

716
717
class BltGlobalTransformer(BltPreTrainedModel):
    """Stack of ``BltTransformerLayer`` modules run over pre-computed embeddings.

    The incoming embeddings are optionally projected to ``config.hidden_size``,
    passed through dropout, and then fed through every layer in order. All
    layers receive the same rotary position embeddings, computed once per call.
    """

    config: BltGlobalTransformerConfig
    # Declares an ``OutputRecorder`` on ``BltSelfAttention`` (index=1, restricted to
    # layer_name="global_transformer") under the "global_attentions" key.
    _can_record_outputs = {
        "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"),
    }

    def __init__(self, config: BltGlobalTransformerConfig):
        """Build the layer stack, rotary embedding and input projection from ``config``.

        Args:
            config: Supplies ``num_hidden_layers``, ``hidden_size`` and, optionally,
                ``encoder_cross_output_size``.
        """
        super().__init__(config)
        self.config = config
        self.layers = nn.ModuleList(
            [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.rotary_emb = BltRotaryEmbedding(config=config)

        # Project the incoming embeddings to the hidden size only when the config
        # declares a cross-output size; otherwise pass them through untouched.
        cross_output_size = getattr(config, "encoder_cross_output_size", None)
        if cross_output_size is None:
            self.token_embedding_projection = nn.Identity()
        else:
            self.token_embedding_projection = nn.Linear(cross_output_size, config.hidden_size, bias=False)

        self.post_init()

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Run the embeddings through the projection, dropout and every layer.

        Args:
            inputs_embeds: 3-D tensor; its first two dimensions are read as
                ``(batch_size, seq_len)``.
            attention_mask: Forwarded unchanged to every layer.
            position_ids: Positions handed to the rotary embedding. When ``None``,
                ``0 .. seq_len - 1`` is used for every batch row.
            past_key_values: Forwarded unchanged to every layer.
            **kwargs: Extra keyword arguments forwarded to every layer.

        Returns:
            The output of the last layer (or the projected, dropped-out embeddings
            when there are no layers).
        """
        batch_size, seq_len, _ = inputs_embeds.shape

        projected = self.token_embedding_projection(inputs_embeds)
        hidden_states = F.dropout(projected, p=self.config.dropout, training=self.training)

        if position_ids is None:
            default_positions = torch.arange(seq_len, device=inputs_embeds.device)
            position_ids = default_positions.unsqueeze(0).expand(batch_size, -1)

        # Computed once and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )
        return hidden_states
767
768
769class BltPatcher(BltPreTrainedModel):

Callers 1

__init__ Method · 0.70

Calls 1

OutputRecorder Class · 0.85

Tested by

no test coverage detected