Postprocess the image output from tensor to `output_type`. Args: image (`torch.Tensor`): The image input; must be a PyTorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image; one of `pil`, `np`, `pt`, `latent`.
(
self,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: list[bool] | None = None,
)
| 736 | return image |
| 737 | |
def postprocess(
    self,
    image: torch.Tensor,
    output_type: str = "pil",
    do_denormalize: list[bool] | None = None,
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
    """
    Convert a tensor image produced by the model into the requested `output_type`.

    Args:
        image (`torch.Tensor`):
            The image to convert; expected to be a pytorch tensor with shape `B x C x H x W`.
        output_type (`str`, *optional*, defaults to `pil`):
            Target format, one of `pil`, `np`, `pt`, `latent`. Any other value emits a deprecation
            notice and is treated as `np`.
        do_denormalize (`list[bool]`, *optional*, defaults to `None`):
            Whether to denormalize the image to [0,1]. If `None`, the value of `do_normalize` in the
            `VaeImageProcessor` config is used.

    Returns:
        `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
            - `latent`: `image`, returned untouched (no denormalization).
            - `pt`: the tensor after conditional denormalization.
            - `np`: that tensor converted with `pt_to_numpy`.
            - `pil`: the numpy result converted with `numpy_to_pil`.
            NOTE(review): confirm whether `numpy_to_pil` returns one image or a list for a batch;
            the annotation assumes a single `PIL.Image.Image`.

    Raises:
        ValueError: If `image` is not a `torch.Tensor`.
    """
    if not isinstance(image, torch.Tensor):
        raise ValueError(
            f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
        )

    supported_types = ("latent", "pt", "np", "pil")
    if output_type not in supported_types:
        message = (
            f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
            "`pil`, `np`, `pt`, `latent`"
        )
        deprecate("Unsupported output_type", "1.0.0", message, standard_warn=False)
        output_type = "np"

    # Latents are handed back as-is, before any denormalization.
    if output_type == "latent":
        return image

    denormalized = self._denormalize_conditionally(image, do_denormalize)
    if output_type == "pt":
        return denormalized

    as_numpy = self.pt_to_numpy(denormalized)
    if output_type == "np":
        return as_numpy

    # Only "pil" can reach this point: every other value either returned above
    # or was coerced to "np" by the deprecation branch.
    return self.numpy_to_pil(as_numpy)
| 787 | |
| 788 | def apply_overlay( |
| 789 | self, |