Map inputs to attention experts according to the routing decision and compute the query projection inside each expert.
(self, layer_input)
| 237 | ) |
| 238 | |
| 239 | def map(self, layer_input): |
| 240 | """ |
| 241 | Map inputs to attention experts according to the routing decision and compute the query projection inside each expert; returns (layer_output [bsz, length, top_k, hidden_size], router_logits, topo_info). |
| 242 | """ |
| 243 | |
| 244 | # Route tokens: flatten to rows, then the router yields the expert assignment (index_sorted_experts, batch_index, expert_size), per-assignment gates (batch_gates) and router_logits |
| 245 | bsz, length, emb_size = layer_input.size() |
| 246 | layer_input = layer_input.reshape(-1, emb_size) # [bsz * length, emb_size] |
| 247 | index_sorted_experts, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) |
| 248 | topo_info = (index_sorted_experts, batch_index, batch_gates, expert_size) # kept so reduce() can reuse the same routing |
| 249 | |
| 250 | # Gather token rows via batch_index (presumably grouped by expert — confirm in router) and apply the per-expert input projection, sized by expert_size |
| 251 | expert_inputs = layer_input[batch_index] # [bsz * length * top_k, emb_size] |
| 252 | expert_outputs = self.input_linear(expert_inputs, expert_size) # [bsz * length * top_k, hidden_size] |
| 253 | |
| 254 | # Scatter expert outputs back to their original (token, top_k) slots: out-of-place index_add into a zero buffer, then reshape per token |
| 255 | zeros = torch.zeros( |
| 256 | (bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device |
| 257 | ) |
| 258 | layer_output = zeros.index_add(0, index_sorted_experts, expert_outputs) |
| 259 | layer_output = layer_output.view(bsz, length, self.top_k, -1) # [bsz, length, top_k, hidden_size] |
| 260 | return layer_output, router_logits, topo_info |
| 261 | |
| 262 | def reduce(self, layer_input, topo_info): |
| 263 | """ |