MCPcopy
hub / github.com/huggingface/transformers / BltLocalEncoder

Class BltLocalEncoder

src/transformers/models/blt/modular_blt.py:529–639  ·  src/transformers/models/blt/modular_blt.py::BltLocalEncoder

Source from the content-addressed store, hash-verified

527
528
529class BltLocalEncoder(BltPreTrainedModel):
530 config: BltLocalEncoderConfig
531 _can_record_outputs = {
532 "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"),
533 }
534
def __init__(self, config: BltLocalEncoderConfig):
    """Create the submodules of the local encoder from ``config``.

    Builds, in order: the stack of transformer layers, the rotary
    embedding, the bias-free patch embedding projection, the token
    embedding table, and the cross-attention layers, then calls
    ``post_init()``.

    Args:
        config: Configuration read for ``num_hidden_layers``,
            ``hidden_size``, ``cross_attn_k``, ``vocab_size`` and
            ``cross_attn_all_layers``.
    """
    super().__init__(config)
    self.gradient_checkpointing = False
    self.config = config

    hidden_size = config.hidden_size
    num_layers = config.num_hidden_layers

    transformer_layers = [BltTransformerLayer(config, depth) for depth in range(num_layers)]
    self.layers = nn.ModuleList(transformer_layers)
    self.rotary_emb = BltRotaryEmbedding(config=config)

    # Projects one hidden vector into `cross_attn_k` hidden-size chunks.
    self.patch_embedding_projection = nn.Linear(
        in_features=hidden_size,
        out_features=hidden_size * config.cross_attn_k,
        bias=False,
    )
    self.embed_tokens = nn.Embedding(config.vocab_size, hidden_size)

    # One cross-attention layer per transformer layer when
    # `cross_attn_all_layers` is set, otherwise a single one.
    cross_attn_count = num_layers if config.cross_attn_all_layers else 1
    self.cross_attn_layers = nn.ModuleList(
        [
            BltCrossAttention(config=config, layer_idx=depth, hidden_size=hidden_size)
            for depth in range(cross_attn_count)
        ]
    )

    self.post_init()
557
558 def forward(
559 self,
560 input_ids: torch.LongTensor | None = None,
561 inputs_embeds: torch.Tensor | None = None,
562 patch_embeds: torch.Tensor | None = None,
563 attention_mask: torch.Tensor | None = None,
564 position_ids: torch.LongTensor | None = None,
565 past_key_values: Cache | None = None,
566 encoder_attention_mask: torch.Tensor | None = None,
567 num_patches: int | None = None,
568 patch_ids: torch.Tensor | None = None,
569 **kwargs: Unpack[TransformersKwargs],
570 ):
571 if inputs_embeds is None:
572 inputs_embeds = self.embed_tokens(input_ids)
573
574 batch_size = inputs_embeds.shape[0]
575 hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training)
576
577 if position_ids is None:
578 position_ids = (
579 torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
580 )
581
582 position_embeddings = self.rotary_emb(hidden_states, position_ids)
583 hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
584
585 for idx, layer in enumerate(self.layers):
586 hidden_states = layer(

Callers 1

__init__ Method · 0.70

Calls 1

OutputRecorder Class · 0.85

Tested by

no test coverage detected