(
self,
model: PreTrainedModel | nn.Module | None = None,
args: TrainingArguments | None = None,
data_collator: DataCollator | None = None,
train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None,
eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
model_init: Callable[..., PreTrainedModel] | None = None,
compute_loss_func: Callable | None = None,
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
)
| 362 | # ---- Initialization & Validation ---- |
| 363 | |
| 364 | def __init__( |
| 365 | self, |
| 366 | model: PreTrainedModel | nn.Module | None = None, |
| 367 | args: TrainingArguments | None = None, |
| 368 | data_collator: DataCollator | None = None, |
| 369 | train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None, |
| 370 | eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None, |
| 371 | processing_class: PreTrainedTokenizerBase |
| 372 | | BaseImageProcessor |
| 373 | | FeatureExtractionMixin |
| 374 | | ProcessorMixin |
| 375 | | None = None, |
| 376 | model_init: Callable[..., PreTrainedModel] | None = None, |
| 377 | compute_loss_func: Callable | None = None, |
| 378 | compute_metrics: Callable[[EvalPrediction], dict] | None = None, |
| 379 | callbacks: list[TrainerCallback] | None = None, |
| 380 | optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), |
| 381 | optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, |
| 382 | preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, |
| 383 | ): |
| 384 | # Init flow: |
| 385 | # 1. Args & seed – defaults, determinism |
| 386 | # 2. Accelerator & logging – accelerator, memory tracker, log level, device setup |
| 387 | # 3. Model resolution – model / model_init, Liger Kernel, quantization checks |
| 388 | # 4. Distributed strategy – model-parallel, FSDP, SageMaker MP flags |
| 389 | # 5. Device placement – move model to device, model wrapping |
| 390 | # 6. Model introspection – loss kwargs, label names, label smoother |
| 391 | # 7. Store init arguments – data, callables, optimizer, scheduler, validation |
| 392 | # 8. Callbacks – reporting integrations, JIT checkpoint, progress bar |
| 393 | # 9. Hub & output – repo init, output directory |
| 394 | # 10. Training state – TrainerState, TrainerControl, internal bookkeeping |
| 395 | # 11. Finalize – use_cache, XLA FSDPv2 mesh, memory tracker stop |
| 396 | |
| 397 | # ---- 1. Args & seed -------------------------------------------------------- |
| 398 | if args is None: |
| 399 | output_dir = "tmp_trainer" |
| 400 | logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") |
| 401 | args = TrainingArguments(output_dir=output_dir) |
| 402 | self.args = args |
| 403 | # Seed must be set before instantiating the model when using model_init |
| 404 | enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) |
| 405 | |
| 406 | # ---- 2. Accelerator & logging ---------------------------------------------- |
| 407 | # `create_accelerator_and_postprocess` reads self.model and self.args, |
| 408 | # and may set self.deepspeed — store temporary refs before calling it. |
| 409 | self.deepspeed = None |
| 410 | self.model = model |
| 411 | self.create_accelerator_and_postprocess() |
| 412 | |
| 413 | self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) |
| 414 | self._memory_tracker.start() |
| 415 | |
| 416 | log_level = args.get_process_log_level() |
| 417 | logging.set_verbosity(log_level) |
| 418 | |
| 419 | args._setup_devices # force device and distributed setup init explicitly |
| 420 | |
| 421 | # ---- 3. Model resolution ---------------------------------------------------- |
nothing calls this directly
no test coverage detected