MCPcopy
hub / github.com/huggingface/transformers / align_special_tokens

Function align_special_tokens

src/transformers/trainer_utils.py:1172–1244  ·  view source on GitHub ↗

Aligns the special tokens of the tokenizer with the model configs. New tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all downstream uses work as expected.

(model, processing_class)

Source from the content-addressed store, hash-verified

1170
1171
1172def align_special_tokens(model, processing_class):
1173 """
1174 Aligns the special tokens of the tokenizer with the model configs.
1175
1176 A new tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be
1177 added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all
1178 downstream uses work as expected. This alignment should happen before training, to ensure the prediction step
1179 uses the new tokens as well.
1180 """
1181 from .processing_utils import ProcessorMixin
1182 from .tokenization_utils_base import PreTrainedTokenizerBase
1183
1184 if isinstance(processing_class, ProcessorMixin):
1185 tokenizer: PreTrainedTokenizerBase = processing_class.tokenizer
1186 else:
1187 tokenizer = processing_class
1188 model_has_generation_config = hasattr(model, "generation_config") and model.generation_config is not None
1189 updated_tokens = {}
1190
1191 # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS
1192 # token.
1193 tokenizer_has_new_eos = tokenizer.eos_token_id != getattr(model.config, "eos_token_id", None)
1194 if model_has_generation_config:
1195 # `generation_config.eos_token_id` is None: direct comparison
1196 if model.generation_config.eos_token_id is None:
1197 tokenizer_has_new_eos |= tokenizer.eos_token_id != model.generation_config.eos_token_id
1198 else:
1199 # `generation_config.eos_token_id` is an `int`: convert it to list (and continue below)
1200 if isinstance(model.generation_config.eos_token_id, int):
1201 model.generation_config.eos_token_id = [model.generation_config.eos_token_id]
1202 # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list
1203 tokenizer_has_new_eos |= tokenizer.eos_token_id not in model.generation_config.eos_token_id
1204
1205 if tokenizer_has_new_eos:
1206 updated_tokens["eos_token_id"] = tokenizer.eos_token_id
1207 model.config.eos_token_id = tokenizer.eos_token_id
1208 # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the
1209 # EOS tokens defined here will halt generation.
1210 if model_has_generation_config:
1211 all_eos_tokens = [tokenizer.eos_token_id]
1212 if model.generation_config.eos_token_id is not None:
1213 all_eos_tokens += list(model.generation_config.eos_token_id)
1214 model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None]
1215
1216 # 2 - Align BOS
1217 tokenizer_has_new_bos = tokenizer.bos_token_id != getattr(model.config, "bos_token_id", None)
1218 if model_has_generation_config:
1219 tokenizer_has_new_bos |= tokenizer.bos_token_id != model.generation_config.bos_token_id
1220
1221 if tokenizer_has_new_bos:
1222 updated_tokens["bos_token_id"] = tokenizer.bos_token_id
1223 model.config.bos_token_id = tokenizer.bos_token_id
1224 if model_has_generation_config:
1225 model.generation_config.bos_token_id = tokenizer.bos_token_id
1226
1227 # 3 - Align PAD
1228 tokenizer_has_new_pad = tokenizer.pad_token_id != getattr(model.config, "pad_token_id", None)
1229 if model_has_generation_config:

Callers 1

trainMethod · 0.85

Calls 1

warningMethod · 0.80

Tested by

no test coverage detected