MCPcopy
hub / github.com/huggingface/transformers / _prepare_input

Method _prepare_input

src/transformers/trainer.py:2185–2201  ·  view source on GitHub ↗

Prepares a single `data` input before feeding it to the model, whether it is a tensor or a nested list/dictionary of tensors.

(self, data: torch.Tensor | Any)

Source from the content-addressed store, hash-verified

2183 return num_items_in_batch
2184
2185 def _prepare_input(self, data: torch.Tensor | Any) -> torch.Tensor | Any:
2186 """
2187 Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
2188 """
2189 if isinstance(data, Mapping):
2190 return type(data)({k: self._prepare_input(v) for k, v in data.items()})
2191 elif isinstance(data, (tuple, list)):
2192 return type(data)(self._prepare_input(v) for v in data)
2193 elif isinstance(data, torch.Tensor):
2194 kwargs = {"device": self.args.device}
2195 if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
2196 # NLP models inputs are int/uint and those get adjusted to the right dtype of the
2197 # embedding. Other models such as wav2vec2's inputs are already float and thus
2198 # may need special handling to match the dtypes of the model
2199 kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
2200 return data.to(**kwargs)
2201 return data
2202
2203 def _prepare_inputs(self, inputs: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
2204 """

Callers 2

_prepare_inputsMethod · 0.95
evaluation_loopMethod · 0.95

Calls 4

itemsMethod · 0.45
updateMethod · 0.45
dtypeMethod · 0.45
toMethod · 0.45

Tested by

no test coverage detected