Add GLPNImageProcessorFast with enhanced 4-channel support for #36978 #40472
Conversation
This commit introduces GLPNImageProcessorFast, a PyTorch-optimized image processor for GLPN models with enhanced multi-channel support.

Key improvements:
- Added GLPNImageProcessorFast class with native PyTorch tensor processing
- Enhanced support for 1, 3, and 4-channel images (including RGBA)
- Optimized preprocessing pipeline using torchvision transforms
- Updated GLPNImageProcessor to support 4-channel inference
- Added comprehensive tests for multi-channel image processing
- Added proper documentation for the new processor

The fast processor leverages PyTorch tensors throughout the processing pipeline, providing better performance and memory efficiency compared to the PIL-based approach. Both processors now support variable channel dimensions for improved flexibility.
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, glpn
Hey @akacmazz, thanks for contributing! Quite a few simplifications are possible here!
```diff
 if input_data_format is None:
     # We assume that all images have the same channel dimension format.
-    input_data_format = infer_channel_dimension_format(images[0])
+    input_data_format = infer_channel_dimension_format(images[0], num_channels=(1, 3, 4))
```
Not sure why that would be needed if the numpy 4 channels test was passing.
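To illustrate the reviewer's point, here is a toy stand-in for the channel inference (the real `infer_channel_dimension_format` lives in `transformers.image_utils`; this simplified version only looks at shapes and is not the actual implementation):

```python
def infer_channel_axis(shape, num_channels=(1, 3)):
    """Toy stand-in: guess whether channels come first or last from the shape."""
    if shape[0] in num_channels:
        return "channels_first"
    if shape[-1] in num_channels:
        return "channels_last"
    raise ValueError(f"Unable to infer channel format for shape {shape}")

# A (64, 64, 4) RGBA array fails with the default (1, 3) accepted counts,
# but succeeds once 4 is added:
print(infer_channel_axis((64, 64, 4), num_channels=(1, 3, 4)))  # channels_last
```

So the `num_channels=(1, 3, 4)` argument only matters if 4-channel inputs actually reach this inference path; if the existing 4-channel numpy test passed without it, the override is likely redundant.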
```python
if is_torchvision_available():
    from torchvision.transforms import functional as F
```
use v2 if available
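A minimal sketch of the "prefer v2, fall back to v1" import the reviewer is suggesting (the module paths are torchvision's real ones; the helper name is made up for illustration, and transformers has its own `is_torchvision_v2_available`-style guards for this):

```python
import importlib


def load_functional_transforms():
    """Return torchvision's v2 functional module if present, else v1, else None."""
    for name in ("torchvision.transforms.v2.functional", "torchvision.transforms.functional"):
        try:
            return importlib.import_module(name)
        except ImportError:
            continue
    return None  # torchvision not installed


F = load_functional_transforms()
```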
```python
@auto_docstring
@requires(backends=("torchvision", "torch"))
```
No need
```suggestion
@requires(backends=("torchvision", "torch"))
```
```python
r"""
Constructs a fast GLPN image processor using PyTorch and TorchVision.

Args:
    do_resize (`bool`, *optional*, defaults to `True`):
        Whether to resize the image's (height, width) dimensions, rounding them down to the closest multiple of
        `size_divisor`. Can be overridden by `do_resize` in `preprocess`.
    size_divisor (`int`, *optional*, defaults to 32):
        When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
        multiple of `size_divisor`. Can be overridden by `size_divisor` in `preprocess`.
    resample (`PIL.Image` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
        Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
    do_rescale (`bool`, *optional*, defaults to `True`):
        Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be
        overridden by `do_rescale` in `preprocess`.
"""
```
Handled by auto_docstring
```python
def __init__(
    self,
    do_resize: bool = True,
    size_divisor: int = 32,
    resample=PILImageResampling.BILINEAR,
    do_rescale: bool = True,
    **kwargs,
) -> None:
    self.do_resize = do_resize
    self.do_rescale = do_rescale
    self.size_divisor = size_divisor
    self.resample = resample
    self.rescale_factor = 1 / 255
    super().__init__(**kwargs)
```
This is not how we handle kwargs in fast image processors. Also, `size_divisor` is a custom kwarg, so let's add it to a `GLPNFastImageProcessorKwargs` class. Look at other fast image processors to see how it's done.
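The pattern the reviewer refers to can be sketched with plain `TypedDict`s. The base-class name mirrors transformers' `DefaultFastImageProcessorKwargs`, but the fields shown here are an illustrative subset, not the real definition:

```python
from typing import Optional, TypedDict


class DefaultFastImageProcessorKwargs(TypedDict, total=False):
    # Illustrative subset of the shared kwargs every fast processor accepts.
    do_resize: bool
    do_rescale: bool


class GLPNFastImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False):
    # The model-specific kwarg is declared once here, rather than
    # hand-rolled assignments in __init__.
    size_divisor: Optional[int]
```

With this in place, the fast processor just points its `valid_kwargs` at the subclass and lets the base `__init__` do the bookkeeping.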
```python
"""
Resize the image, rounding the (height, width) dimensions down to the closest multiple of size_divisor.

If the image is of dimension (3, 260, 170) and size_divisor is 32, the image will be resized to (3, 256, 160).

Args:
    image (`torch.Tensor`):
        The image to resize.
    size_divisor (`int`):
        The image is resized so its height and width are rounded down to the closest multiple of
        `size_divisor`.
    interpolation (`F.InterpolationMode`, *optional*):
        Resampling filter to use when resizing the image e.g. `F.InterpolationMode.BILINEAR`.

Returns:
    `torch.Tensor`: The resized image.
"""
```
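The rounding behaviour the docstring describes reduces to integer floor division; a standalone sketch (the helper name is mine, not the processor's API):

```python
def round_down_to_multiple(value: int, divisor: int) -> int:
    """Largest multiple of `divisor` that is <= `value`."""
    return (value // divisor) * divisor


# Matches the docstring example: (260, 170) with size_divisor=32 -> (256, 160)
new_height = round_down_to_multiple(260, 32)  # 256
new_width = round_down_to_multiple(170, 32)  # 160
```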
No need, self-explanatory
```python
def _process_image(
    self,
    image: "torch.Tensor",
    do_convert_rgb: bool = True,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Process a single image tensor, supporting variable channel dimensions including 4-channel images.

    Overrides the base class method to support 1, 3, and 4 channel images.
    """
    if is_torch_available():
        import torch
        from torchvision.transforms.functional import pil_to_tensor

    # Convert PIL image to tensor if needed
    if isinstance(image, torch.Tensor):
        # Already a tensor, just ensure it's float
        if image.dtype != torch.float32:
            image = image.float()
    elif hasattr(image, "mode") and hasattr(image, "size"):  # PIL Image
        image = pil_to_tensor(image).float()
    else:
        # Assume it's a numpy array
        image = torch.from_numpy(image)
        if image.dtype != torch.float32:
            image = image.float()

    # Infer the channel dimension format if not provided, supporting 1, 3, and 4 channels
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))

    if input_data_format == ChannelDimension.LAST:
        # We force the channel dimension to be first for torch tensors as this is what torchvision expects.
        image = image.permute(2, 0, 1).contiguous()

    # Now that we have torch tensors, we can move them to the right device
    if device is not None:
        image = image.to(device)

    return image
```
No need for that, handled by the parent class
```python
attrs_to_remove = [
    "crop_size",
    "do_center_crop",
    "do_normalize",
    "image_mean",
    "image_std",
    "do_convert_rgb",
    "size",
    "input_data_format",
    "device",
    "return_tensors",
    "disable_grouping",
    "rescale_factor",
]
```
most of these shouldn't be removed
```python
@auto_docstring
def preprocess(
    self,
    images: ImageInput,
    do_resize: Optional[bool] = None,
    size_divisor: Optional[int] = None,
    resample=None,
    do_rescale: Optional[bool] = None,
    return_tensors: Optional[Union[TensorType, str]] = None,
    data_format: ChannelDimension = ChannelDimension.FIRST,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> BatchFeature:
    r"""
    Preprocess the given images.

    Args:
        images (`ImageInput`):
            Images to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the input such that the (height, width) dimensions are a multiple of `size_divisor`.
        size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
            When `do_resize` is `True`, images are resized so their height and width are rounded down to the
            closest multiple of `size_divisor`.
        resample (`PIL.Image` resampling filter, *optional*, defaults to `self.resample`):
            `PIL.Image` resampling filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
            an effect if `do_resize` is set to `True`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
        return_tensors (`str` or `TensorType`, *optional*):
            The type of tensors to return. Can be one of:
            - `None`: Return a list of `torch.Tensor`.
            - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
            The channel dimension format for the output image. Can be one of:
            - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
    """
    do_resize = do_resize if do_resize is not None else self.do_resize
    do_rescale = do_rescale if do_rescale is not None else self.do_rescale
    size_divisor = size_divisor if size_divisor is not None else self.size_divisor
    resample = resample if resample is not None else self.resample

    # Convert PIL resampling to torchvision InterpolationMode
    if is_torchvision_available():
        from ...image_utils import pil_torch_interpolation_mapping

        interpolation = (
            pil_torch_interpolation_mapping[resample]
            if isinstance(resample, (PILImageResampling, int))
            else resample
        )
    else:
        interpolation = F.InterpolationMode.BILINEAR

    # Prepare images
    images = self._prepare_image_like_inputs(
        images=images,
        do_convert_rgb=False,  # Don't force RGB conversion to support variable channels
        input_data_format=input_data_format,
    )

    return self._preprocess(
        images=images,
        do_resize=do_resize,
        size_divisor=size_divisor,
        interpolation=interpolation,
        do_rescale=do_rescale,
        rescale_factor=self.rescale_factor,
        return_tensors=return_tensors,
        **kwargs,
    )

def _preprocess(
    self,
    images: list["torch.Tensor"],
    do_resize: bool,
    size_divisor: Optional[int],
    interpolation: Optional["F.InterpolationMode"],
    do_rescale: bool,
    rescale_factor: float,
    return_tensors: Optional[Union[str, TensorType]],
    **kwargs,
) -> BatchFeature:
    """
    Preprocess the images for GLPN.

    Args:
        images (`list[torch.Tensor]`):
            List of images to preprocess.
        do_resize (`bool`):
            Whether to resize the images.
        size_divisor (`int`, *optional*):
            Size divisor for resizing. If None, uses self.size_divisor.
        interpolation (`F.InterpolationMode`, *optional*):
            Interpolation mode for resizing.
        do_rescale (`bool`):
            Whether to rescale pixel values to [0, 1].
        rescale_factor (`float`):
            Factor to rescale pixel values by.
        return_tensors (`str` or `TensorType`, *optional*):
            Type of tensors to return.

    Returns:
        `BatchFeature`: Processed images in a BatchFeature.
    """
    if size_divisor is None:
        size_divisor = self.size_divisor

    processed_images = []

    for image in images:
        # Resize if needed
        if do_resize:
            image = self.resize(image, size_divisor=size_divisor, interpolation=interpolation)

        # Rescale to [0, 1] if needed
        if do_rescale:
            image = self.rescale(image, scale=rescale_factor)

        processed_images.append(image)

    # Stack images into a batch if return_tensors is specified
    if return_tensors:
        processed_images = torch.stack(processed_images, dim=0)

    return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
```
This shouldn't need to be overridden, as the preprocessing is quite simple. Parameters just need to be set accordingly (`do_normalize=False` etc.)
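What the reviewer describes — configuring the shared pipeline via class-level defaults instead of overriding `_preprocess` — could look roughly like this. This is a sketch with assumed attribute names; the real contract comes from `BaseImageProcessorFast`, whose generic `_preprocess` reads these flags:

```python
class GLPNImageProcessorFastSketch:
    # Class-level defaults; a real fast processor would inherit from
    # BaseImageProcessorFast and let its generic _preprocess consume these.
    do_resize = True
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = False  # GLPN skips mean/std normalization
    do_center_crop = False
    size_divisor = 32
```

The subclass then carries no custom preprocessing logic at all, only configuration.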
Closing this PR as #41725 is about to be merged |
What does this PR do?
This commit introduces GLPNImageProcessorFast, a PyTorch-optimized image processor for GLPN models with enhanced multi-channel support.
Key improvements:
- Added GLPNImageProcessorFast class with native PyTorch tensor processing
- Enhanced support for 1, 3, and 4-channel images (including RGBA)
- Optimized preprocessing pipeline using torchvision transforms
- Updated GLPNImageProcessor to support 4-channel inference
- Added comprehensive tests for multi-channel image processing
- Added proper documentation for the new processor
The fast processor leverages PyTorch tensors throughout the processing pipeline, providing better performance and memory efficiency compared to the PIL-based approach. Both processors now support variable channel dimensions for improved flexibility.
Technical details:
This enhancement enables GLPN models to work seamlessly with RGBA images and other multi-channel inputs, which is particularly useful for computer vision applications involving images with transparency channels.
Fixes #36978
Who can review?
@yonigozlan, @amyeroberts, @qubvel
This PR focuses on vision models (GLPN) and adds a new fast image processor with enhanced channel support. The changes include both the core implementation and comprehensive testing.