| 99 | requires_grad=False) |
| 100 | |
def append(self, logits: torch.Tensor):
    """Append a block of logits to this storage.

    A 2-D ``logits`` tensor gets a singleton dimension inserted at axis 1,
    so it is treated as ``[num_tokens, 1, vocab_size]``. After that the
    tensor must be 3-D.

    Chunked mode (``self.use_chunked_generation_logits`` is truthy):
        ``self._init_chunked_storage`` is called first while
        ``self.beam_width == -1``, then the tensor is passed to
        ``self._add_fragment``.

    Non-chunked mode:
        ``self._init`` is called first while ``self.beam_width == -1``.
        The tensor is then copied into the next free rows of
        ``self._storage``, and the ``(start, end)`` row range is appended
        to ``self._logits_indices``.

    Raises:
        AssertionError: if ``logits`` is not 3-D after the optional
            unsqueeze, or (non-chunked mode only) if its size along dim 1
            differs from ``self.beam_width``.
        ValueError: (non-chunked mode only) if the new rows would extend
            past ``self.seq_length``.
    """
    if logits.ndim == 2:
        logits = logits.unsqueeze(1)
    assert logits.ndim == 3, f"Bad logits shape, expect [num_tokens, beam_width, vocab_size], got {logits.shape}"

    # A beam width of -1 marks storage that has not been set up yet.
    needs_setup = self.beam_width == -1

    if self.use_chunked_generation_logits:
        if needs_setup:
            self._init_chunked_storage(logits)
        self._add_fragment(logits)
        return

    if needs_setup:
        self._init(logits)

    assert logits.size(1) == self.beam_width, "Beam width mismatch"

    # New rows start where the most recently recorded range ended.
    num_new = logits.size(0)
    start = self._logits_indices[-1][1] if self._logits_indices else 0
    end = start + num_new

    if end > self.seq_length:
        raise ValueError(
            f"LogitsStorage overflow. This storage can only hold {self.seq_length} logits "
            f"({start} already filled) but trying to append {num_new} more logits"
        )

    self._storage[start:end].copy_(logits, non_blocking=True)
    self._logits_indices.append((start, end))
| 129 | def get(self, all_logits: bool, exclude_last: bool) -> torch.Tensor | None: |
| 130 | """Returns the used logits storage if there are any, otherwise, returns None. |