Postprocess the image output from tensor to `output_type`. Args: image (`torch.Tensor`): The image input, should be a pytorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
(
self,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: list[bool] | None = None,
)
| 1102 | return pil_images |
| 1103 | |
| 1104 | def postprocess( |
| 1105 | self, |
| 1106 | image: torch.Tensor, |
| 1107 | output_type: str = "pil", |
| 1108 | do_denormalize: list[bool] | None = None, |
| 1109 | ) -> PIL.Image.Image | np.ndarray | torch.Tensor: |
| 1110 | """ |
| 1111 | Postprocess the image output from tensor to `output_type`. |
| 1112 | |
| 1113 | Args: |
| 1114 | image (`torch.Tensor`): |
| 1115 | The image input, should be a pytorch tensor with shape `B x C x H x W`. |
| 1116 | output_type (`str`, *optional*, defaults to `pil`): |
| 1117 | The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. |
| 1118 | do_denormalize (`list[bool]`, *optional*, defaults to `None`): |
| 1119 | Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the |
| 1120 | `VaeImageProcessor` config. |
| 1121 | |
| 1122 | Returns: |
| 1123 | `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: |
| 1124 | The postprocessed image. |
| 1125 | """ |
| 1126 | if not isinstance(image, torch.Tensor): |
| 1127 | raise ValueError( |
| 1128 | f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" |
| 1129 | ) |
| 1130 | if output_type not in ["latent", "pt", "np", "pil"]: |
| 1131 | deprecation_message = ( |
| 1132 | f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " |
| 1133 | "`pil`, `np`, `pt`, `latent`" |
| 1134 | ) |
| 1135 | deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) |
| 1136 | output_type = "np" |
| 1137 | |
| 1138 | image = self._denormalize_conditionally(image, do_denormalize) |
| 1139 | |
| 1140 | image = self.pt_to_numpy(image) |
| 1141 | |
| 1142 | if output_type == "np": |
| 1143 | if image.shape[-1] == 6: |
| 1144 | image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0) |
| 1145 | else: |
| 1146 | image_depth = image[:, :, :, 3:] |
| 1147 | return image[:, :, :, :3], image_depth |
| 1148 | |
| 1149 | if output_type == "pil": |
| 1150 | return self.numpy_to_pil(image), self.numpy_to_depth(image) |
| 1151 | else: |
| 1152 | raise Exception(f"This type {output_type} is not supported") |
| 1153 | |
| 1154 | def preprocess( |
| 1155 | self, |
nothing calls this directly
no test coverage detected