Prepares a single `data` input before it is fed to the model, whether it is a tensor or a nested list/tuple/dictionary of tensors.
(self, data: torch.Tensor | Any)
| 2183 | return num_items_in_batch |
| 2184 | |
| 2185 | def _prepare_input(self, data: torch.Tensor | Any) -> torch.Tensor | Any: |
| 2186 | """ |
| 2187 | Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. |
| 2188 | """ |
| 2189 | if isinstance(data, Mapping): |
| 2190 | return type(data)({k: self._prepare_input(v) for k, v in data.items()}) |
| 2191 | elif isinstance(data, (tuple, list)): |
| 2192 | return type(data)(self._prepare_input(v) for v in data) |
| 2193 | elif isinstance(data, torch.Tensor): |
| 2194 | kwargs = {"device": self.args.device} |
| 2195 | if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): |
| 2196 | # NLP models inputs are int/uint and those get adjusted to the right dtype of the |
| 2197 | # embedding. Other models such as wav2vec2's inputs are already float and thus |
| 2198 | # may need special handling to match the dtypes of the model |
| 2199 | kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) |
| 2200 | return data.to(**kwargs) |
| 2201 | return data |
| 2202 | |
| 2203 | def _prepare_inputs(self, inputs: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: |
| 2204 | """ |
no test coverage detected