r""" Convert a NumPy image to a PyTorch tensor. Args: images (`np.ndarray`): The NumPy image array to convert to PyTorch format. Returns: `torch.Tensor`: A PyTorch tensor representation of the images.
(images: np.ndarray)
| 170 | |
@staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
    r"""
    Convert a channels-last NumPy image array to a channels-first PyTorch tensor.

    Args:
        images (`np.ndarray`):
            The NumPy image array to convert to PyTorch format. Accepted layouts:
            `(N, H, W, C)`; `(N, H, W)`, treated as a batch of single-channel images;
            and `(H, W)`, treated as one single-channel image.

    Returns:
        `torch.Tensor`:
            A PyTorch tensor of shape `(N, C, H, W)`. It shares memory with `images`
            unless `images` has negative strides, in which case a copy is made.

    Raises:
        ValueError: If `images` is not 2-, 3- or 4-dimensional.
    """
    if images.ndim == 2:
        # A single (H, W) image: add both the leading batch axis and the trailing channel axis.
        images = images[None, ..., None]
    elif images.ndim == 3:
        # A 3-D array is treated as a batch of single-channel images: add a channel axis.
        images = images[..., None]
    elif images.ndim != 4:
        raise ValueError(
            f"Expected a 2-, 3- or 4-dimensional image array, got shape {images.shape}."
        )

    # Move the last (channel) axis to position 1: (N, H, W, C) -> (N, C, H, W).
    images = images.transpose(0, 3, 1, 2)
    # torch.from_numpy rejects arrays with negative strides (e.g. a channel-flipped view
    # such as `images[..., ::-1]`), so materialise a copy only in that case; otherwise the
    # returned tensor keeps sharing memory with the input.
    if any(stride < 0 for stride in images.strides):
        images = np.ascontiguousarray(images)
    return torch.from_numpy(images)
| 190 | @staticmethod |
| 191 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray: |