(logits, temperature: float = 1.0, top_k: int | None = None)
| 155 | return probs |
| 156 | |
def sample(logits, temperature: float = 1.0, top_k: int | None = None):
    """Draw the next token id from the final position of ``logits``.

    Only ``logits[0, -1]`` is consumed: index 0 along the first axis and the
    last entry along the second axis. NOTE(review): presumably ``logits`` is
    laid out as (batch, seq_len, vocab) with batch size 1 here -- confirm
    against callers before relying on other layouts.

    Args:
        logits: Indexable score tensor; see the layout note above.
        temperature: Forwarded unchanged to ``logits_to_probs``.
        top_k: Forwarded unchanged to ``logits_to_probs`` (``None`` is passed
            through as-is).

    Returns:
        A ``(idx_next, probs)`` tuple, where ``probs`` is the output of
        ``logits_to_probs`` and ``idx_next`` is the result of
        ``multinomial_sample_one_no_sync(probs)``.
    """
    final_step_logits = logits[0, -1]
    next_token_probs = logits_to_probs(final_step_logits, temperature, top_k)
    # The tuple is evaluated left to right, so sampling happens after the
    # probabilities are computed, matching the original statement order.
    return multinomial_sample_one_no_sync(next_token_probs), next_token_probs
| 161 | |
| 162 | # First eager forward pass |
| 163 | logger.info("running first eager forward pass") |
no test coverage detected