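Saves a model checkpoint on TPU/XLA devices. Handles FSDP v1 sharded checkpoints (with consolidation on master), as well as standard XLA model saving via `save_pretrained` or `xm.save`. Args: model (`torch.nn.Module`): The model to save. args (`TrainingArguments`): The training arguments. accelerator (`Accelerator`): The accelerator instance. processing_class: The processing class (tokenizer/processor) to save alongside the model. is_fsdp_xla_v1_enabled (`bool`): Whether FSDP XLA v1 is enabled. output_dir (`str`, *optional*): The directory to save to. Defaults to `args.output_dir`.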
(model, args, accelerator, processing_class, is_fsdp_xla_v1_enabled, output_dir=None)
| 163 | |
| 164 | |
| 165 | def save_tpu_checkpoint(model, args, accelerator, processing_class, is_fsdp_xla_v1_enabled, output_dir=None): |
| 166 | """ |
| 167 | Saves a model checkpoint on TPU/XLA devices. |
| 168 | |
| 169 | Handles FSDP v1 sharded checkpoints (with consolidation on master), as well as |
| 170 | standard XLA model saving via `save_pretrained` or `xm.save`. |
| 171 | |
| 172 | Args: |
| 173 | model (`torch.nn.Module`): The model to save. |
| 174 | args (`TrainingArguments`): The training arguments. |
| 175 | accelerator (`Accelerator`): The accelerator instance. |
| 176 | processing_class: The processing class (tokenizer/processor) to save alongside the model. |
| 177 | is_fsdp_xla_v1_enabled (`bool`): Whether FSDP XLA v1 is enabled. |
| 178 | output_dir (`str`, *optional*): The directory to save to. Defaults to `args.output_dir`. |
| 179 | """ |
| 180 | import torch_xla.core.xla_model as xm |
| 181 | |
| 182 | output_dir = output_dir if output_dir is not None else args.output_dir |
| 183 | |
| 184 | logger.info(f"Saving model checkpoint to {output_dir}") |
| 185 | xm.mark_step() |
| 186 | |
| 187 | if xm.is_master_ordinal(local=False): |
| 188 | os.makedirs(output_dir, exist_ok=True) |
| 189 | torch.save(args, os.path.join(output_dir, "training_args.bin")) |
| 190 | |
| 191 | # Save a trained model and configuration using `save_pretrained()`. |
| 192 | # They can then be reloaded using `from_pretrained()` |
| 193 | supported_classes = (PushToHubMixin,) |
| 194 | xm.rendezvous("saving_checkpoint") |
| 195 | if is_fsdp_xla_v1_enabled: |
| 196 | ckpt = { |
| 197 | "model": model.state_dict(), |
| 198 | "shard_metadata": model.get_shard_metadata(), |
| 199 | } |
| 200 | ckpt_path = os.path.join(output_dir, f"rank{args.process_index}-of-{args.world_size}-{WEIGHTS_NAME}") |
| 201 | # All ranks save sharded checkpoint |
| 202 | xm.save(ckpt, ckpt_path, master_only=False) |
| 203 | # Make sure all ranks have saved checkpoints |
| 204 | xm.rendezvous("save_full_checkpoints") |
| 205 | # Master save full checkpoint |
| 206 | if args.should_save: |
| 207 | from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints |
| 208 | |
| 209 | full_state_dict, _ = consolidate_sharded_model_checkpoints( |
| 210 | ckpt_prefix=os.path.join(output_dir, ""), |
| 211 | ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}", |
| 212 | save_model=False, |
| 213 | ) |
| 214 | model = model.module.module |
| 215 | unwrapped_model = accelerator.unwrap_model(model) |
| 216 | if isinstance(unwrapped_model, supported_classes): |
| 217 | unwrapped_model.save_pretrained(output_dir, state_dict=full_state_dict) |
| 218 | else: |
| 219 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") |
| 220 | xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) |
| 221 | elif not isinstance(model, supported_classes): |
| 222 | if isinstance(accelerator.unwrap_model(model), supported_classes): |
no test coverage detected