Aligns the special tokens of the tokenizer with the model configs. New tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be added to chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all downstream uses work as expected. This alignment should happen before training, to ensure the prediction step uses the new tokens as well.
(model, processing_class)
| 1170 | |
| 1171 | |
| 1172 | def align_special_tokens(model, processing_class): |
| 1173 | """ |
| 1174 | Aligns the special tokens of the tokenizer with the model configs. |
| 1175 | |
| 1176 | A new tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be |
| 1177 | added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all |
| 1178 | downstream uses work as expected. This alignment should happen before training, to ensure the prediction step |
| 1179 | uses the new tokens as well. |
| 1180 | """ |
| 1181 | from .processing_utils import ProcessorMixin |
| 1182 | from .tokenization_utils_base import PreTrainedTokenizerBase |
| 1183 | |
| 1184 | if isinstance(processing_class, ProcessorMixin): |
| 1185 | tokenizer: PreTrainedTokenizerBase = processing_class.tokenizer |
| 1186 | else: |
| 1187 | tokenizer = processing_class |
| 1188 | model_has_generation_config = hasattr(model, "generation_config") and model.generation_config is not None |
| 1189 | updated_tokens = {} |
| 1190 | |
| 1191 | # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS |
| 1192 | # token. |
| 1193 | tokenizer_has_new_eos = tokenizer.eos_token_id != getattr(model.config, "eos_token_id", None) |
| 1194 | if model_has_generation_config: |
| 1195 | # `generation_config.eos_token_id` is None: direct comparison |
| 1196 | if model.generation_config.eos_token_id is None: |
| 1197 | tokenizer_has_new_eos |= tokenizer.eos_token_id != model.generation_config.eos_token_id |
| 1198 | else: |
| 1199 | # `generation_config.eos_token_id` is an `int`: convert it to list (and continue below) |
| 1200 | if isinstance(model.generation_config.eos_token_id, int): |
| 1201 | model.generation_config.eos_token_id = [model.generation_config.eos_token_id] |
| 1202 | # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list |
| 1203 | tokenizer_has_new_eos |= tokenizer.eos_token_id not in model.generation_config.eos_token_id |
| 1204 | |
| 1205 | if tokenizer_has_new_eos: |
| 1206 | updated_tokens["eos_token_id"] = tokenizer.eos_token_id |
| 1207 | model.config.eos_token_id = tokenizer.eos_token_id |
| 1208 | # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the |
| 1209 | # EOS tokens defined here will halt generation. |
| 1210 | if model_has_generation_config: |
| 1211 | all_eos_tokens = [tokenizer.eos_token_id] |
| 1212 | if model.generation_config.eos_token_id is not None: |
| 1213 | all_eos_tokens += list(model.generation_config.eos_token_id) |
| 1214 | model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None] |
| 1215 | |
| 1216 | # 2 - Align BOS |
| 1217 | tokenizer_has_new_bos = tokenizer.bos_token_id != getattr(model.config, "bos_token_id", None) |
| 1218 | if model_has_generation_config: |
| 1219 | tokenizer_has_new_bos |= tokenizer.bos_token_id != model.generation_config.bos_token_id |
| 1220 | |
| 1221 | if tokenizer_has_new_bos: |
| 1222 | updated_tokens["bos_token_id"] = tokenizer.bos_token_id |
| 1223 | model.config.bos_token_id = tokenizer.bos_token_id |
| 1224 | if model_has_generation_config: |
| 1225 | model.generation_config.bos_token_id = tokenizer.bos_token_id |
| 1226 | |
| 1227 | # 3 - Align PAD |
| 1228 | tokenizer_has_new_pad = tokenizer.pad_token_id != getattr(model.config, "pad_token_id", None) |
| 1229 | if model_has_generation_config: |