MCPcopy
hub / github.com/huggingface/transformers / _wrap_model

Method _wrap_model

src/transformers/trainer.py:2429–2461  ·  view source on GitHub ↗

Wrap `model` for distributed training if needed (DDP, FSDP, SageMaker, etc.).

(self, model: nn.Module, training: bool = True, dataloader: DataLoader | None = None)

Source from the content-addressed store, hash-verified

2427 return 1
2428
def _wrap_model(self, model: nn.Module, training: bool = True, dataloader: DataLoader | None = None) -> nn.Module:
    """
    Return `model` wrapped for the active parallel/distributed backend, or unchanged if no wrapper applies.

    The checks run in this order, and the first one that returns wins:

    1. A model that the accelerator can already unwrap is returned untouched.
    2. With SageMaker model parallelism enabled, the model is returned as (or wrapped into) an
       `smp.DistributedModel`, regardless of `training`.
    3. With more than one GPU and a model not flagged as loaded in 8-bit or 4-bit, the model is wrapped in
       `nn.DataParallel`.
    4. Outside of training, the result of the steps above is returned as is.
    5. During training, the model is additionally wrapped for FSDP-XLA or SageMaker data parallelism when
       one of them is enabled.

    Args:
        model (`nn.Module`):
            The model to wrap.
        training (`bool`, *optional*, defaults to `True`):
            When `False`, the FSDP-XLA and SageMaker data-parallel wrappers are skipped.
        dataloader (`DataLoader`, *optional*):
            Not read by this method.

    Returns:
        `nn.Module`: The wrapped model, or `model` itself when nothing had to be wrapped.
    """
    # Train/eval may call this repeatedly: if unwrapping yields a different object, the model is
    # already wrapped and must not be wrapped a second time.
    unwrapped = self.accelerator.unwrap_model(model, keep_torch_compile=False)
    if unwrapped is not model:
        return model

    if is_sagemaker_mp_enabled():
        # Wrapping an existing DistributedModel in another DistributedModel raises an error.
        already_distributed = isinstance(model, smp.model.DistributedModel)
        if already_distributed:
            return model
        return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)

    # Multi-GPU case: quantized (8-bit / 4-bit) models do not support DataParallel.
    if self.args.n_gpu > 1 and not (
        getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
    ):
        model = nn.DataParallel(model)

    # In torch.distributed mode there is no point in using DistributedDataParallel for evaluation,
    # since everything runs under `no_grad` anyway.
    if not training:
        return model

    # Distributed training using PyTorch FSDP on XLA.
    if self.is_fsdp_xla_enabled:
        return wrap_model_xla_fsdp(model, self.args, self.is_fsdp_xla_v2_enabled)

    if is_sagemaker_dp_enabled():
        local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
        return nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    return model
2462
2463 def _update_auto_batch_size(self, batch_size):
2464 """Free memory, reset model wrapping, and update DeepSpeed config for the new batch size when using `auto_find_batch_size`"""

Callers 2

_prepare_for_training — Method · 0.95
evaluation_loop — Method · 0.95

Calls 3

is_sagemaker_mp_enabled — Function · 0.85
wrap_model_xla_fsdp — Function · 0.85
is_sagemaker_dp_enabled — Function · 0.85

Tested by

no test coverage detected