MCPcopy
hub / github.com/huggingface/transformers / BltRotaryEmbedding

Class BltRotaryEmbedding

src/transformers/models/blt/modular_blt.py:272–286  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

270
271
class BltRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary position embedding (RoPE) used by BLT.

    Subclasses ``LlamaRotaryEmbedding`` and overrides only ``forward``. The
    difference called out in the code is how the per-frequency angles are
    widened to the full embedding width. Each angle is repeated in place
    (``[a0, a0, a1, a1, ...]``, via ``repeat_interleave``). Concatenating
    ``(angles, angles)`` would instead give the half-split layout
    ``[a0, a1, ..., a0, a1, ...]``.
    """

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Build the ``(cos, sin)`` rotary tables for ``position_ids``.

        Args:
            x: Reference tensor. Only its ``device.type`` and its ``dtype`` are
                read. The device type picks the autocast context and the dtype
                is used for the outputs. Its values are not used.
            position_ids: Positions to embed. They are indexed as
                ``position_ids[:, None, :]``, so presumably shaped
                ``(batch, seq_len)``. TODO: confirm with callers.

        Returns:
            ``(cos, sin)``. Each is scaled by ``self.attention_scaling``, which
            is presumably set by the parent ``__init__``, and cast to
            ``x.dtype``. Under the assumed input shape both are
            ``(batch, seq_len, 2 * len(self.inv_freq))``.
        """
        n_batch = position_ids.shape[0]
        # inv_freq must be 1-D: the [None, :, None] index plus the 3-size expand
        # require it. Lift it to an (n_batch, n_freq, 1) column of frequencies.
        freq_col = self.inv_freq[None, :, None].float().expand(n_batch, -1, 1)
        pos_row = position_ids[:, None, :].float()

        # Autocast is switched off so the angle math below stays in float32.
        # "mps", and any non-string device type, falls back to the "cpu" autocast context.
        dev_kind = x.device.type
        autocast_device = dev_kind if isinstance(dev_kind, str) and dev_kind != "mps" else "cpu"
        with maybe_autocast(device_type=autocast_device, enabled=False):
            # For 2-D position_ids this is a per-batch outer product frequency x position.
            # It is then transposed to (batch, seq, n_freq).
            angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            # diff from Llama: each angle is repeated in place instead of cat((angles, angles)).
            angles = angles.repeat_interleave(2, dim=-1)
            scale = self.attention_scaling
            cos, sin = angles.cos() * scale, angles.sin() * scale

        return cos.to(x.dtype), sin.to(x.dtype)
287
288
289class BltTransformerLayer(MllamaSelfAttentionDecoderLayer):

Callers 4

__init__ · Method · 0.70
__init__ · Method · 0.70
__init__ · Method · 0.70
__init__ · Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected