Move the model to the specified device, re-tying weights on XLA if needed.
(self, model: nn.Module, device: torch.device)
| 4429 | return self.args.process_index == 0 |
| 4430 | |
def _move_model_to_device(self, model: nn.Module, device: torch.device) -> None:
    """Place ``model`` on ``device``, restoring tied weights when running on TPU.

    A model that exposes a non-``None`` ``hf_device_map`` attribute is treated
    as already spread over several devices: a warning is logged and the model
    is left where it is.

    Args:
        model: The module to relocate.
        device: The target device handed to ``model.to``.
    """
    already_dispatched = getattr(model, "hf_device_map", None) is not None
    if already_dispatched:
        logger.warning(
            "The model is already on multiple devices. Skipping the move to device specified in `args`."
        )
        return

    relocated = model.to(device)

    # Transferring to an XLA device breaks the link between tied weights,
    # so they are re-tied whenever the model provides a `tie_weights` hook.
    running_on_tpu = self.args.parallel_mode == ParallelMode.TPU
    if running_on_tpu and hasattr(relocated, "tie_weights"):
        relocated.tie_weights()
no test coverage detected