(probs_sort)
| 141 | |
| 142 | # Copied from the gpt-fast repo |
def multinomial_sample_one_no_sync(probs_sort):
    """Does multinomial sampling without a cuda synchronization.

    Each entry of ``probs_sort`` is divided by an independent Exponential(1)
    draw and the position of the largest ratio along the last dimension is
    returned, so one index is picked per row.

    Args:
        probs_sort: Tensor of sampling weights; the choice is made along its
            last dimension. (Presumably already-normalized probabilities --
            the code itself does not check or require this.)

    Returns:
        Tensor of dtype ``torch.int`` holding the chosen index, with the last
        dimension kept at size 1.
    """
    # Fresh noise with the same shape/dtype/device as the input, filled
    # in place with Exponential(rate=1) samples.
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)

    # Largest weight-to-noise ratio wins.
    ratios = probs_sort / noise
    chosen = ratios.argmax(dim=-1, keepdim=True)
    return chosen.to(dtype=torch.int)
| 146 | |
| 147 | def logits_to_probs(logits, temperature: float = 1.0, top_k: int | None = None): |
| 148 | logits = logits / max(temperature, 1e-5) |