MCPcopy
hub / github.com/huggingface/transformers / save_tpu_checkpoint

Function save_tpu_checkpoint

src/transformers/integrations/tpu.py:165–239  ·  view source on GitHub ↗

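Saves a model checkpoint on TPU/XLA devices. Handles FSDP v1 sharded checkpoints (with consolidation on master), as well as standard XLA model saving via `save_pretrained` or `xm.save`. Args: model (`torch.nn.Module`): The model to save. args (`TrainingArguments`): The training arguments. accelerator (`Accelerator`): The accelerator instance. processing_class: The processing class (tokenizer/processor) to save alongside the model. is_fsdp_xla_v1_enabled (`bool`): Whether FSDP XLA v1 is enabled. output_dir (`str`, *optional*): The directory to save to. Defaults to `args.output_dir`.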

(model, args, accelerator, processing_class, is_fsdp_xla_v1_enabled, output_dir=None)

Source from the content-addressed store, hash-verified

163
164
165def save_tpu_checkpoint(model, args, accelerator, processing_class, is_fsdp_xla_v1_enabled, output_dir=None):
166 """
167 Saves a model checkpoint on TPU/XLA devices.
168
169 Handles FSDP v1 sharded checkpoints (with consolidation on master), as well as
170 standard XLA model saving via `save_pretrained` or `xm.save`.
171
172 Args:
173 model (`torch.nn.Module`): The model to save.
174 args (`TrainingArguments`): The training arguments.
175 accelerator (`Accelerator`): The accelerator instance.
176 processing_class: The processing class (tokenizer/processor) to save alongside the model.
177 is_fsdp_xla_v1_enabled (`bool`): Whether FSDP XLA v1 is enabled.
178 output_dir (`str`, *optional*): The directory to save to. Defaults to `args.output_dir`.
179 """
180 import torch_xla.core.xla_model as xm
181
182 output_dir = output_dir if output_dir is not None else args.output_dir
183
184 logger.info(f"Saving model checkpoint to {output_dir}")
185 xm.mark_step()
186
187 if xm.is_master_ordinal(local=False):
188 os.makedirs(output_dir, exist_ok=True)
189 torch.save(args, os.path.join(output_dir, "training_args.bin"))
190
191 # Save a trained model and configuration using `save_pretrained()`.
192 # They can then be reloaded using `from_pretrained()`
193 supported_classes = (PushToHubMixin,)
194 xm.rendezvous("saving_checkpoint")
195 if is_fsdp_xla_v1_enabled:
196 ckpt = {
197 "model": model.state_dict(),
198 "shard_metadata": model.get_shard_metadata(),
199 }
200 ckpt_path = os.path.join(output_dir, f"rank{args.process_index}-of-{args.world_size}-{WEIGHTS_NAME}")
201 # All ranks save sharded checkpoint
202 xm.save(ckpt, ckpt_path, master_only=False)
203 # Make sure all ranks have saved checkpoints
204 xm.rendezvous("save_full_checkpoints")
205 # Master save full checkpoint
206 if args.should_save:
207 from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
208
209 full_state_dict, _ = consolidate_sharded_model_checkpoints(
210 ckpt_prefix=os.path.join(output_dir, ""),
211 ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
212 save_model=False,
213 )
214 model = model.module.module
215 unwrapped_model = accelerator.unwrap_model(model)
216 if isinstance(unwrapped_model, supported_classes):
217 unwrapped_model.save_pretrained(output_dir, state_dict=full_state_dict)
218 else:
219 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
220 xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
221 elif not isinstance(model, supported_classes):
222 if isinstance(accelerator.unwrap_model(model), supported_classes):

Callers 1

save_model · Method · 0.85

Calls 5

join · Method · 0.80
info · Method · 0.45
save · Method · 0.45
state_dict · Method · 0.45
save_pretrained · Method · 0.45

Tested by

no test coverage detected