Wrap `model` for distributed training if needed (DDP, FSDP, SageMaker, etc.).
(self, model: nn.Module, training: bool = True, dataloader: DataLoader | None = None)
| 2427 | return 1 |
| 2428 | |
| 2429 | def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataLoader | None = None) -> nn.Module: |
| 2430 | """Wrap `model` for distributed training if needed (DDP, FSDP, SageMaker, etc.).""" |
| 2431 | # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again |
| 2432 | if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model: |
| 2433 | return model |
| 2434 | |
| 2435 | if is_sagemaker_mp_enabled(): |
| 2436 | # Wrapping the base model twice in a DistributedModel will raise an error. |
| 2437 | if isinstance(model, smp.model.DistributedModel): |
| 2438 | return model |
| 2439 | return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) |
| 2440 | |
| 2441 | # Multi-gpu training, quantized models do not support DP |
| 2442 | if ( |
| 2443 | self.args.n_gpu > 1 |
| 2444 | and not getattr(model, "is_loaded_in_8bit", False) |
| 2445 | and not getattr(model, "is_loaded_in_4bit", False) |
| 2446 | ): |
| 2447 | model = nn.DataParallel(model) |
| 2448 | |
| 2449 | # Note: in torch.distributed mode, there's no point in wrapping the model |
| 2450 | # inside a DistributedDataParallel as we'll be under `no_grad` anyways. |
| 2451 | if not training: |
| 2452 | return model |
| 2453 | |
| 2454 | # Distributed training using PyTorch FSDP |
| 2455 | if self.is_fsdp_xla_enabled: |
| 2456 | model = wrap_model_xla_fsdp(model, self.args, self.is_fsdp_xla_v2_enabled) |
| 2457 | elif is_sagemaker_dp_enabled(): |
| 2458 | model = nn.parallel.DistributedDataParallel( |
| 2459 | model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] |
| 2460 | ) |
| 2461 | return model |
| 2462 | |
| 2463 | def _update_auto_batch_size(self, batch_size): |
| 2464 | """Free memory, reset model wrapping, and update DeepSpeed config for the new batch size when using `auto_find_batch_size`""" |
no test coverage detected