| 527 | |
| 528 | |
| 529 | class BltLocalEncoder(BltPreTrainedModel): |
# Type of the configuration object this encoder is built from.
config: BltLocalEncoderConfig
# Maps the output name "encoder_attentions" to an OutputRecorder targeting
# BltSelfAttention (index=1, layer_name="local_encoder").
# NOTE(review): presumably consumed by the base class' output-recording
# machinery to expose attention weights -- confirm against BltPreTrainedModel.
_can_record_outputs = {
    "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"),
}
| 534 | |
def __init__(self, config: BltLocalEncoderConfig):
    """Build the local encoder's sub-modules from ``config``.

    Creates, in this order: the stack of transformer layers, the rotary
    embedding, the bias-free patch-embedding projection, the token
    embedding table and the cross-attention layers, then calls
    ``post_init()``.

    Args:
        config: Encoder configuration. This constructor reads
            ``num_hidden_layers``, ``hidden_size``, ``cross_attn_k``,
            ``vocab_size`` and ``cross_attn_all_layers`` from it.
    """
    super().__init__(config)
    self.config = config
    self.gradient_checkpointing = False

    depth = config.num_hidden_layers
    width = config.hidden_size

    # Sub-modules are created in a fixed order so parameter registration
    # (and any random initialisation done in the constructors) is stable.
    encoder_blocks = [BltTransformerLayer(config, block_idx) for block_idx in range(depth)]
    self.layers = nn.ModuleList(encoder_blocks)
    self.rotary_emb = BltRotaryEmbedding(config=config)
    # Projects one hidden vector to `cross_attn_k` hidden-sized vectors.
    self.patch_embedding_projection = nn.Linear(width, width * config.cross_attn_k, bias=False)
    self.embed_tokens = nn.Embedding(config.vocab_size, width)

    # Either one cross-attention layer per transformer layer, or a single one.
    num_cross_attn = depth if config.cross_attn_all_layers else 1
    cross_blocks = [
        BltCrossAttention(config=config, layer_idx=cross_idx, hidden_size=width)
        for cross_idx in range(num_cross_attn)
    ]
    self.cross_attn_layers = nn.ModuleList(cross_blocks)

    self.post_init()
| 557 | |
| 558 | def forward( |
| 559 | self, |
| 560 | input_ids: torch.LongTensor | None = None, |
| 561 | inputs_embeds: torch.Tensor | None = None, |
| 562 | patch_embeds: torch.Tensor | None = None, |
| 563 | attention_mask: torch.Tensor | None = None, |
| 564 | position_ids: torch.LongTensor | None = None, |
| 565 | past_key_values: Cache | None = None, |
| 566 | encoder_attention_mask: torch.Tensor | None = None, |
| 567 | num_patches: int | None = None, |
| 568 | patch_ids: torch.Tensor | None = None, |
| 569 | **kwargs: Unpack[TransformersKwargs], |
| 570 | ): |
| 571 | if inputs_embeds is None: |
| 572 | inputs_embeds = self.embed_tokens(input_ids) |
| 573 | |
| 574 | batch_size = inputs_embeds.shape[0] |
| 575 | hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training) |
| 576 | |
| 577 | if position_ids is None: |
| 578 | position_ids = ( |
| 579 | torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1) |
| 580 | ) |
| 581 | |
| 582 | position_embeddings = self.rotary_emb(hidden_states, position_ids) |
| 583 | hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training) |
| 584 | |
| 585 | for idx, layer in enumerate(self.layers): |
| 586 | hidden_states = layer( |