Wrap model, create optimizer and scheduler, and run accelerator.prepare. Returns (model, train_dataloader).
(self, max_steps, train_dataloader, resume_from_checkpoint)
| 1562 | return epochs_trained, steps_trained_in_current_epoch |
| 1563 | |
| 1564 | def _prepare_for_training(self, max_steps, train_dataloader, resume_from_checkpoint): |
| 1565 | """Wrap model, create optimizer and scheduler, and run accelerator.prepare. Returns (model, train_dataloader).""" |
| 1566 | delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled |
| 1567 | |
| 1568 | # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404 |
| 1569 | is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2) |
| 1570 | if is_fsdp2: |
| 1571 | delay_optimizer_creation = False |
| 1572 | |
| 1573 | # We need to reset the scheduler, as its parameters may be different on subsequent calls |
| 1574 | if self._created_lr_scheduler: |
| 1575 | self.lr_scheduler = None |
| 1576 | self._created_lr_scheduler = False |
| 1577 | |
| 1578 | if self.is_deepspeed_enabled: |
| 1579 | self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) |
| 1580 | |
| 1581 | if not delay_optimizer_creation: |
| 1582 | self.create_optimizer() |
| 1583 | |
| 1584 | # Pass `self.model_wrapped` so that `_wrap_model` can detect if the model is already |
| 1585 | # wrapped (e.g. in DataParallel) on subsequent `train()` calls and avoid double wrapping. |
| 1586 | model = self._wrap_model(self.model_wrapped) |
| 1587 | |
| 1588 | # If the model is wrapped, don't use `accelerator.prepare` |
| 1589 | # this is for unhandled cases in accelerate such as FSDP-XLA, SageMaker MP/DP, DataParallel |
| 1590 | use_accelerator_prepare = model is self.model |
| 1591 | |
| 1592 | # prepare using `accelerator` prepare |
| 1593 | if use_accelerator_prepare: |
| 1594 | if delay_optimizer_creation: |
| 1595 | # TODO: check if we can move this somewhere else |
| 1596 | if self.is_fsdp_enabled and _is_peft_model(self.model): |
| 1597 | update_fsdp_plugin_peft(self.model, self.accelerator) |
| 1598 | # we only prepare the model as we don't have an optimizer |
| 1599 | model = self.accelerator.prepare(self.model) |
| 1600 | # using the model we prepared to create the optimizer |
| 1601 | self.create_optimizer(model) |
| 1602 | self.optimizer = self.accelerator.prepare(self.optimizer) |
| 1603 | elif self.is_deepspeed_enabled and type(self.lr_scheduler).__name__ == "DummyScheduler": |
| 1604 | model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( |
| 1605 | self.model, self.optimizer, self.lr_scheduler |
| 1606 | ) |
| 1607 | else: |
| 1608 | model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) |
| 1609 | else: |
| 1610 | self.optimizer = self.accelerator.prepare(self.optimizer) |
| 1611 | |
| 1612 | # Create scheduler now that the optimizer won't change anymore |
| 1613 | self.create_scheduler(num_training_steps=max_steps) |
| 1614 | |
| 1615 | # updating self.model_wrapped |
| 1616 | self.model_wrapped = model |
| 1617 | |
| 1618 | if self.is_fsdp_enabled or self.is_fsdp_xla_enabled: |
| 1619 | # breaking convention for FSDP model |
| 1620 | # TODO: check if this is really needed |
| 1621 | self.model = self.model_wrapped = model |
no test coverage detected