MCPcopy
hub / github.com/huggingface/transformers / __init__

Method __init__

src/transformers/trainer.py:364–618  ·  view source on GitHub ↗
(
        self,
        model: PreTrainedModel | nn.Module | None = None,
        args: TrainingArguments | None = None,
        data_collator: DataCollator | None = None,
        train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None,
        eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None,
        processing_class: PreTrainedTokenizerBase
        | BaseImageProcessor
        | FeatureExtractionMixin
        | ProcessorMixin
        | None = None,
        model_init: Callable[..., PreTrainedModel] | None = None,
        compute_loss_func: Callable | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
        optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    )

Source from the content-addressed store, hash-verified

362 # ---- Initialization & Validation ----
363
364 def __init__(
365 self,
366 model: PreTrainedModel | nn.Module | None = None,
367 args: TrainingArguments | None = None,
368 data_collator: DataCollator | None = None,
369 train_dataset: "Dataset | IterableDataset | datasets.Dataset | None" = None,
370 eval_dataset: "Dataset | dict[str, Dataset] | datasets.Dataset | None" = None,
371 processing_class: PreTrainedTokenizerBase
372 | BaseImageProcessor
373 | FeatureExtractionMixin
374 | ProcessorMixin
375 | None = None,
376 model_init: Callable[..., PreTrainedModel] | None = None,
377 compute_loss_func: Callable | None = None,
378 compute_metrics: Callable[[EvalPrediction], dict] | None = None,
379 callbacks: list[TrainerCallback] | None = None,
380 optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
381 optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
382 preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
383 ):
384 # Init flow:
385 # 1. Args & seed – defaults, determinism
386 # 2. Accelerator & logging – accelerator, memory tracker, log level, device setup
387 # 3. Model resolution – model / model_init, Liger Kernel, quantization checks
388 # 4. Distributed strategy – model-parallel, FSDP, SageMaker MP flags
389 # 5. Device placement – move model to device, model wrapping
390 # 6. Model introspection – loss kwargs, label names, label smoother
391 # 7. Store init arguments – data, callables, optimizer, scheduler, validation
392 # 8. Callbacks – reporting integrations, JIT checkpoint, progress bar
393 # 9. Hub & output – repo init, output directory
394 # 10. Training state – TrainerState, TrainerControl, internal bookkeeping
395 # 11. Finalize – use_cache, XLA FSDPv2 mesh, memory tracker stop
396
397 # ---- 1. Args & seed --------------------------------------------------------
398 if args is None:
399 output_dir = "tmp_trainer"
400 logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
401 args = TrainingArguments(output_dir=output_dir)
402 self.args = args
403 # Seed must be set before instantiating the model when using model_init
404 enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
405
406 # ---- 2. Accelerator & logging ----------------------------------------------
407 # `create_accelerator_and_postprocess` reads self.model and self.args,
408 # and may set self.deepspeed — store temporary refs before calling it.
409 self.deepspeed = None
410 self.model = model
411 self.create_accelerator_and_postprocess()
412
413 self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
414 self._memory_tracker.start()
415
416 log_level = args.get_process_log_level()
417 logging.set_verbosity(log_level)
418
419 args._setup_devices # force device and distributed setup init explicitly
420
421 # ---- 3. Model resolution ----------------------------------------------------

Callers

nothing calls this directly

Calls 15

get_process_log_level · Method · 0.95
call_model_init · Method · 0.95
_move_model_to_device · Method · 0.95
_validate_args · Method · 0.95
set_trainer · Method · 0.95
add_callback · Method · 0.95
init_hf_repo · Method · 0.95
is_local_process_zero · Method · 0.95
is_world_process_zero · Method · 0.95
TrainingArguments · Class · 0.85
enable_full_determinism · Function · 0.85

Tested by

no test coverage detected