Method prediction_step

Repository: github.com/huggingface/transformers
File: src/transformers/trainer.py, lines 2904–3007

Perform an evaluation step on `model` using `inputs`. Subclass and override to inject custom behavior. Returns a tuple `(loss, logits, labels)` in which each element may be `None`.
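The checks at the top of the method decide whether labels (the third element of the returned tuple) are collected at all. Below is a minimal pure-Python sketch of that decision, assuming plain dicts in place of tensor batches. `resolve_label_flags` is a hypothetical helper, not part of the Trainer API, and `can_return_loss` is passed in explicitly rather than read from `self`.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def resolve_label_flags(
    inputs: Mapping[str, Any],
    label_names: Sequence[str],
    can_return_loss: bool,
) -> tuple[bool, bool]:
    """Hypothetical helper: return (has_labels, loss_without_labels) for one batch."""
    # Labels count as present only if every configured name maps to a non-None value.
    has_labels = len(label_names) > 0 and all(inputs.get(k) is not None for k in label_names)
    # An explicit `return_loss` in the batch wins; otherwise use the model's default.
    return_loss = inputs.get("return_loss")
    if return_loss is None:
        return_loss = can_return_loss
    # CLIP-like models may produce a loss even though no label names are configured.
    loss_without_labels = len(label_names) == 0 and bool(return_loss)
    return has_labels, loss_without_labels


# A batch that omits "end_positions" is treated as unlabeled.
print(resolve_label_flags({"start_positions": 0}, ["start_positions", "end_positions"], False))  # (False, False)
```

The sketch coerces `return_loss` with `bool()`; the source keeps whatever value `return_loss` holds, which only matters if a non-boolean is passed.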

prediction_step(
    self,
    model: nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    prediction_loss_only: bool,
    ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]
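When `ignore_keys` is left as `None`, the method fills it in from the model config. The sketch below reproduces that fallback chain with a hypothetical `resolve_ignore_keys` function, using `SimpleNamespace` objects as stand-ins for models. Note that the real method inspects `self.model`, not the `model` argument, so the sketch's `model` parameter corresponds to the Trainer's own model.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_ignore_keys(model: Any, ignore_keys: Optional[list[str]] = None) -> list[str]:
    """Hypothetical helper: fill in a missing `ignore_keys` the way prediction_step does."""
    if ignore_keys is not None:
        # An explicit list from the caller is used as-is.
        return ignore_keys
    if hasattr(model, "config"):
        # Prefer the config's own list; otherwise ignore only the KV cache.
        return getattr(model.config, "keys_to_ignore_at_inference", ["past_key_values"])
    # A model without a config ignores nothing.
    return []


# Stand-in models: SimpleNamespace replaces real nn.Module instances here.
configured = SimpleNamespace(config=SimpleNamespace(keys_to_ignore_at_inference=["mems"]))
bare_config = SimpleNamespace(config=SimpleNamespace())
print(resolve_ignore_keys(configured))          # ['mems']
print(resolve_ignore_keys(bare_config))         # ['past_key_values']
print(resolve_ignore_keys(SimpleNamespace()))   # []
```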

Source (lines 2902–2961):

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`dict[str, torch.Tensor | Any]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`list[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss")
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = len(self.label_names) == 0 and return_loss

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", ["past_key_values"])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
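Labels are read out of `inputs` before the forward pass because computing the loss may pop them (the source comment cites label smoothing). A minimal sketch of that extraction step follows, using a hypothetical `gather_labels` helper; the `nested_detach` call is omitted because no tensors are involved, and the `wanted` flag stands for `has_labels or loss_without_labels`.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def gather_labels(inputs: Mapping[str, Any], label_names: Sequence[str], wanted: bool) -> Any:
    """Hypothetical helper: collect label values before the forward pass."""
    if not wanted:
        # Neither labels nor a label-free loss: the third tuple element is None.
        return None
    # Grab labels up front because computing the loss may pop them from `inputs`.
    labels = tuple(inputs.get(name) for name in label_names)
    # A single label name is unwrapped so callers get the value itself, not a 1-tuple.
    if len(labels) == 1:
        labels = labels[0]
    return labels


print(gather_labels({"labels": [1, 2]}, ["labels"], True))                                # [1, 2]
print(gather_labels({"start_positions": 3, "end_positions": 7},
                    ["start_positions", "end_positions"], True))                          # (3, 7)
```

One edge case follows from the source: when `loss_without_labels` is true, `label_names` is empty, so the gathered labels are an empty tuple rather than `None`.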

Callers (2; 1 listed)
evaluation_loop (method)

Calls (12; 10 listed)
_prepare_inputs (method)
compute_loss (method)
nested_detach (function)
is_sagemaker_mp_enabled (function)
smp_forward_only (function)
smp_nested_concat (function)
detach (method)
get (method)
items (method)
mean (method)

Tested by 1
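`nested_detach`, listed among the calls, is what the source applies to the gathered labels. The sketch below illustrates the general idea of such a helper, a recursive walk that calls `.detach()` on every leaf of a nested list, tuple, or dict. It is an illustration under assumptions, not the library's implementation; `FakeTensor` and `nested_detach_sketch` are hypothetical names, and `FakeTensor` stands in for `torch.Tensor` so the example runs without PyTorch.

```python
from collections.abc import Mapping
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that records whether it was detached."""

    def __init__(self, value: float, detached: bool = False) -> None:
        self.value = value
        self.detached = detached

    def detach(self) -> "FakeTensor":
        # Return a new object, as torch's detach() returns a new tensor.
        return FakeTensor(self.value, detached=True)


def nested_detach_sketch(tensors: Any) -> Any:
    """Recursively call .detach() on every leaf of a nested list/tuple/dict."""
    if isinstance(tensors, (list, tuple)):
        # Rebuild the same container type around the detached children.
        return type(tensors)(nested_detach_sketch(t) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach_sketch(v) for k, v in tensors.items()})
    # Leaf: assume it is tensor-like and exposes .detach().
    return tensors.detach()


out = nested_detach_sketch((FakeTensor(1.0), [FakeTensor(2.0)], {"a": FakeTensor(3.0)}))
print(out[0].detached, out[1][0].detached, out[2]["a"].detached)  # True True True
```

Rebuilding with `type(tensors)(...)` preserves list versus tuple, which keeps the returned labels shaped like the input. It would not work for namedtuples, whose constructors take positional fields.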