| 270 | |
| 271 | |
class BltRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary position embedding that interleaves its frequencies.

    Unlike the Llama parent, each frequency is repeated twice in adjacent
    slots of the last dimension (``repeat_interleave``) rather than the
    frequency table being concatenated with itself.
    """

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Compute the cos/sin rotary tables for the given positions.

        Args:
            x: Tensor used only for its ``device`` and ``dtype``.
            position_ids: Position indices, indexed here as ``(batch, seq)``.

        Returns:
            A ``(cos, sin)`` tuple, both cast to ``x.dtype``. Each has the
            layout ``(batch, seq, 2 * len(self.inv_freq))`` and is scaled by
            ``self.attention_scaling``.
        """
        batch_size = position_ids.shape[0]
        # (n_freq,) -> (batch, n_freq, 1) and (batch, seq) -> (batch, 1, seq)
        inv_freq = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)
        positions = position_ids[:, None, :].float()

        # "mps" (or a non-string device type) is swapped for "cpu" before
        # being handed to the autocast helper.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # (batch, n_freq, 1) @ (batch, 1, seq) -> transpose -> (batch, seq, n_freq)
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            # diff from Llama: we interleave() instead of cat()
            interleaved = torch.repeat_interleave(angles, 2, dim=-1)
            scale = self.attention_scaling
            cos = interleaved.cos() * scale
            sin = interleaved.sin() * scale

        target_dtype = x.dtype
        return cos.to(dtype=target_dtype), sin.to(dtype=target_dtype)
| 287 | |
| 288 | |
| 289 | class BltTransformerLayer(MllamaSelfAttentionDecoderLayer): |