MCPcopy
hub / github.com/huggingface/transformers / BltPatcher

Class BltPatcher

src/transformers/models/blt/modular_blt.py:769–899  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

767
768
769 class BltPatcher(BltPreTrainedModel):
770 config: BltPatcherConfig
771
772 def __init__(self, config: BltPatcherConfig):
773 super().__init__(config)
774 self.rotary_emb = BltRotaryEmbedding(config=self.config)
775 self.layers = nn.ModuleList()
776 for layer_idx in range(self.config.num_hidden_layers):
777 self.layers.append(BltTransformerLayer(self.config, layer_idx))
778 self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
779 self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
780 self.lm_head = nn.Linear(
781 self.config.hidden_size,
782 self.config.vocab_size,
783 bias=False,
784 )
785
786 self.post_init()
787
788 def forward(
789 self,
790 input_ids: torch.LongTensor | None = None,
791 attention_mask: torch.Tensor | None = None,
792 position_ids: torch.LongTensor | None = None,
793 past_key_values: Cache | None = None,
794 inputs_embeds: torch.FloatTensor | None = None,
795 use_cache: bool | None = None,
796 patch_size: int | None = None,
797 threshold: float | None = None,
798 max_patch_length: int | None = None,
799 **kwargs: Unpack[TransformersKwargs],
800 ):
801 if (input_ids is None) ^ (inputs_embeds is not None):
802 raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
803
804 if inputs_embeds is None:
805 inputs_embeds = self.embed_tokens(input_ids)
806
807 if use_cache and past_key_values is None:
808 past_key_values = DynamicCache(config=self.config)
809
810 if position_ids is None:
811 past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
812 position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
813 position_ids = position_ids.unsqueeze(0)
814
815 causal_mask = create_causal_mask(
816 config=self.config,
817 inputs_embeds=inputs_embeds,
818 attention_mask=attention_mask,
819 past_key_values=past_key_values,
820 position_ids=position_ids,
821 )
822
823 hidden_states = inputs_embeds
824 position_embeddings = self.rotary_emb(hidden_states, position_ids)
825
826 for layer in self.layers:

Callers 1

__init__ · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected