MCPcopy
hub / github.com/huggingface/transformers / BltLocalDecoder

Class BltLocalDecoder

src/transformers/models/blt/modular_blt.py:642–715  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

640
641
642class BltLocalDecoder(BltPreTrainedModel):
643 config: BltLocalDecoderConfig
644
def __init__(self, config: BltLocalDecoderConfig):
    """Build the local decoder's sub-modules from ``config``.

    Args:
        config: Decoder configuration. This constructor reads
            ``num_hidden_layers``, ``hidden_size``, ``hidden_size_global``,
            ``cross_attn_k``, ``cross_attn_all_layers`` and ``rms_norm_eps``
            from it.
    """
    super().__init__(config)
    self.gradient_checkpointing = False
    self.config = config
    self.cross_attn_decoder = True

    # One transformer layer per configured hidden layer.
    self.layers = nn.ModuleList(
        [BltTransformerLayer(config, idx) for idx in range(config.num_hidden_layers)]
    )
    self.rotary_emb = BltRotaryEmbedding(config=config)

    # Bias-free projection from the global hidden size to
    # ``cross_attn_k`` local-hidden-size vectors per input position.
    projected_width = config.hidden_size * config.cross_attn_k
    self.patch_embedding_projection = nn.Linear(
        in_features=config.hidden_size_global,
        out_features=projected_width,
        bias=False,
    )
    self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    # Either one cross-attention module per hidden layer, or a single one.
    if config.cross_attn_all_layers:
        num_cross_attn_layers = config.num_hidden_layers
    else:
        num_cross_attn_layers = 1
    self.cross_attn_layers = nn.ModuleList(
        [
            BltCrossAttention(config=config, layer_idx=idx, hidden_size=config.hidden_size)
            for idx in range(num_cross_attn_layers)
        ]
    )

    # Finalisation hook; not defined in this excerpt.
    self.post_init()
668
669 def forward(
670 self,
671 input_ids: torch.LongTensor | None = None,
672 inputs_embeds: torch.Tensor | None = None,
673 patch_embeds: torch.Tensor | None = None,
674 attention_mask: torch.Tensor | None = None,
675 position_ids: torch.LongTensor | None = None,
676 past_key_values: Cache | None = None,
677 encoder_attention_mask: torch.Tensor | None = None,
678 **kwargs: Unpack[TransformersKwargs],
679 ):
680 batch_size = inputs_embeds.shape[0]
681 hidden_states = inputs_embeds
682 patch_embeds = self.patch_embedding_projection(patch_embeds)
683 patch_embeds = patch_embeds.reshape(
684 batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
685 )
686
687 if patch_embeds is not None and not self.cross_attn_decoder:
688 hidden_states = hidden_states + patch_embeds
689
690 if position_ids is None:
691 position_ids = (
692 torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
693 )
694
695 position_embeddings = self.rotary_emb(hidden_states, position_ids)
696 hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
697
698 for i, layer in enumerate(self.layers):
699 if i == 0 or self.config.cross_attn_all_layers:

Callers 1

__init__Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected