
Conversation

@akacmazz

What does this PR do?

This commit introduces GLPNImageProcessorFast, a PyTorch-optimized image processor for GLPN models with enhanced multi-channel support.

Key improvements:

  • Added GLPNImageProcessorFast class with native PyTorch tensor processing
  • Enhanced support for 1, 3, and 4-channel images (including RGBA)
  • Optimized preprocessing pipeline using torchvision transforms
  • Updated GLPNImageProcessor to support 4-channel inference
  • Added comprehensive tests for multi-channel image processing
  • Added proper documentation for the new processor

The fast processor leverages PyTorch tensors throughout the processing pipeline, providing better performance and memory efficiency compared to the PIL-based approach. Both processors now support variable channel dimensions for improved flexibility.

Technical details:

  • Uses torchvision.transforms for efficient tensor-based preprocessing
  • Implements proper channel dimension handling with infer_channel_dimension_format(num_channels=(1, 3, 4))
  • Maintains API compatibility with existing GLPNImageProcessor
  • Provides significant performance improvements for PyTorch workflows

This enhancement enables GLPN models to work seamlessly with RGBA images and other multi-channel inputs, which is particularly useful for computer vision applications involving images with transparency channels.
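For illustration (this is not code from the PR), the rounding-down behavior driven by `size_divisor` boils down to integer arithmetic:

```python
def round_down_to_multiple(value: int, divisor: int) -> int:
    """Round value down to the closest multiple of divisor."""
    return (value // divisor) * divisor

# With size_divisor=32, (height, width) = (260, 170) becomes (256, 160),
# the behavior described for GLPN's resize step.
height, width = 260, 170
new_hw = (round_down_to_multiple(height, 32), round_down_to_multiple(width, 32))
print(new_hw)  # (256, 160)
```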

Fixes #36978

Who can review?

@yonigozlan, @amyeroberts, @qubvel

This PR focuses on vision models (GLPN) and adds a new fast image processor with enhanced channel support. The changes include both the core implementation and comprehensive testing.

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, glpn

@yonigozlan (Member) left a comment


Hey @akacmazz, thanks for contributing! Quite a few simplifications are possible here!

    if input_data_format is None:
        # We assume that all images have the same channel dimension format.
-       input_data_format = infer_channel_dimension_format(images[0])
+       input_data_format = infer_channel_dimension_format(images[0], num_channels=(1, 3, 4))


Not sure why that would be needed if the numpy 4 channels test was passing.
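For context, a toy sketch (not the transformers implementation) of why the `num_channels` argument can matter: the inference heuristic checks whether the first or last axis length is an allowed channel count, so a 4-channel image is only recognized when 4 is in the allowed set:

```python
def infer_channel_dim(shape, num_channels=(1, 3)):
    """Toy sketch of channel-dimension inference for a 3D image shape.

    Not the real transformers function; it only illustrates the role of
    the num_channels allow-list.
    """
    first, last = shape[0], shape[-1]
    if first in num_channels and last not in num_channels:
        return "channels_first"
    if last in num_channels and first not in num_channels:
        return "channels_last"
    if first in num_channels:  # ambiguous case: prefer channels-first
        return "channels_first"
    raise ValueError(f"Unable to infer channel dimension from shape {shape}")

# An RGBA image is rejected with the default allow-list, but accepted
# once 4 is added:
infer_channel_dim((4, 64, 64), num_channels=(1, 3, 4))  # "channels_first"
```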

Comment on lines +32 to +33
if is_torchvision_available():
    from torchvision.transforms import functional as F


use v2 if available
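One common shape for this (a sketch; transformers has its own `is_torchvision_v2_available` helper, and the function name here is illustrative):

```python
import importlib.util

def get_torchvision_functional():
    """Return torchvision's functional transforms, preferring the v2 API.

    Falls back to the v1 namespace on older torchvision, and returns None
    when torchvision is not installed at all.
    """
    if importlib.util.find_spec("torchvision") is None:
        return None
    try:
        # The v2 namespace ships with torchvision >= 0.15
        from torchvision.transforms.v2 import functional as F
    except ImportError:
        from torchvision.transforms import functional as F
    return F
```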



@auto_docstring
@requires(backends=("torchvision", "torch"))


No need

Suggested change
@requires(backends=("torchvision", "torch"))

Comment on lines +42 to +57
r"""
Constructs a fast GLPN image processor using PyTorch and TorchVision.

Args:
    do_resize (`bool`, *optional*, defaults to `True`):
        Whether to resize the image's (height, width) dimensions, rounding them down to the closest multiple of
        `size_divisor`. Can be overridden by `do_resize` in `preprocess`.
    size_divisor (`int`, *optional*, defaults to 32):
        When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest
        multiple of `size_divisor`. Can be overridden by `size_divisor` in `preprocess`.
    resample (`PIL.Image` resampling filter, *optional*, defaults to `PILImageResampling.BILINEAR`):
        Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
    do_rescale (`bool`, *optional*, defaults to `True`):
        Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be
        overridden by `do_rescale` in `preprocess`.
"""


Handled by auto_docstring

Comment on lines +61 to +74
def __init__(
    self,
    do_resize: bool = True,
    size_divisor: int = 32,
    resample=PILImageResampling.BILINEAR,
    do_rescale: bool = True,
    **kwargs,
) -> None:
    self.do_resize = do_resize
    self.do_rescale = do_rescale
    self.size_divisor = size_divisor
    self.resample = resample
    self.rescale_factor = 1 / 255
    super().__init__(**kwargs)


This is not how we handle kwargs in fast image processors. Also, size_divisor is a custom kwarg, so let's add it to a GLPNFastImageProcessorKwargs class. Look at other fast image processors to see how it's done.
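A rough, self-contained sketch of the pattern the reviewer is pointing at (real fast processors inherit from a shared base class and kwargs TypedDict in transformers; the class names and mechanics below are illustrative only):

```python
from typing import Optional, TypedDict

class GLPNFastImageProcessorKwargs(TypedDict, total=False):
    # Custom kwarg specific to GLPN, declared on top of the shared defaults.
    size_divisor: Optional[int]

class GLPNImageProcessorFastSketch:
    # Defaults live as class attributes; a shared base __init__ copies any
    # recognized kwargs onto the instance instead of a hand-written __init__.
    do_resize = True
    do_rescale = True
    size_divisor = 32
    valid_kwargs = GLPNFastImageProcessorKwargs

    def __init__(self, **kwargs):
        for key in self.valid_kwargs.__annotations__:
            if key in kwargs:
                setattr(self, key, kwargs[key])

processor = GLPNImageProcessorFastSketch(size_divisor=64)
print(processor.size_divisor)  # 64
```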

Comment on lines +83 to +99
"""
Resize the image, rounding the (height, width) dimensions down to the closest multiple of size_divisor.

If the image is of dimension (3, 260, 170) and size_divisor is 32, the image will be resized to (3, 256, 160).

Args:
    image (`torch.Tensor`):
        The image to resize.
    size_divisor (`int`):
        The image is resized so its height and width are rounded down to the closest multiple of
        `size_divisor`.
    interpolation (`F.InterpolationMode`, *optional*):
        Resampling filter to use when resizing the image e.g. `F.InterpolationMode.BILINEAR`.

Returns:
    `torch.Tensor`: The resized image.
"""


No need, self explanatory

Suggested change
"""
Resize the image, rounding the (height, width) dimensions down to the closest multiple of size_divisor.
If the image is of dimension (3, 260, 170) and size_divisor is 32, the image will be resized to (3, 256, 160).
Args:
    image (`torch.Tensor`):
        The image to resize.
    size_divisor (`int`):
        The image is resized so its height and width are rounded down to the closest multiple of
        `size_divisor`.
    interpolation (`F.InterpolationMode`, *optional*):
        Resampling filter to use when resizing the image e.g. `F.InterpolationMode.BILINEAR`.
Returns:
    `torch.Tensor`: The resized image.
"""

Comment on lines +114 to +155
def _process_image(
    self,
    image: "torch.Tensor",
    do_convert_rgb: bool = True,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Process a single image tensor, supporting variable channel dimensions including 4-channel images.

    Overrides the base class method to support 1, 3, and 4 channel images.
    """
    if is_torch_available():
        import torch
        from torchvision.transforms.functional import pil_to_tensor

    # Convert PIL image to tensor if needed
    if isinstance(image, torch.Tensor):
        # Already a tensor, just ensure it's float
        if image.dtype != torch.float32:
            image = image.float()
    elif hasattr(image, "mode") and hasattr(image, "size"):  # PIL Image
        image = pil_to_tensor(image).float()
    else:
        # Assume it's numpy array
        image = torch.from_numpy(image)
        if image.dtype != torch.float32:
            image = image.float()

    # Infer the channel dimension format if not provided, supporting 1, 3, and 4 channels
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image, num_channels=(1, 3, 4))

    if input_data_format == ChannelDimension.LAST:
        # We force the channel dimension to be first for torch tensors as this is what torchvision expects.
        image = image.permute(2, 0, 1).contiguous()

    # Now that we have torch tensors, we can move them to the right device
    if device is not None:
        image = image.to(device)

    return image


No need for that, handled by the parent class

Comment on lines +166 to +179
attrs_to_remove = [
    "crop_size",
    "do_center_crop",
    "do_normalize",
    "image_mean",
    "image_std",
    "do_convert_rgb",
    "size",
    "input_data_format",
    "device",
    "return_tensors",
    "disable_grouping",
    "rescale_factor",
]


most of these shouldn't be removed

Comment on lines +186 to +319
@auto_docstring
def preprocess(
    self,
    images: ImageInput,
    do_resize: Optional[bool] = None,
    size_divisor: Optional[int] = None,
    resample=None,
    do_rescale: Optional[bool] = None,
    return_tensors: Optional[Union[TensorType, str]] = None,
    data_format: ChannelDimension = ChannelDimension.FIRST,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    **kwargs,
) -> BatchFeature:
    r"""
    Preprocess the given images.

    Args:
        images (`ImageInput`):
            Images to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
            passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
            Whether to resize the input such that the (height, width) dimensions are a multiple of `size_divisor`.
        size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
            When `do_resize` is `True`, images are resized so their height and width are rounded down to the
            closest multiple of `size_divisor`.
        resample (`PIL.Image` resampling filter, *optional*, defaults to `self.resample`):
            `PIL.Image` resampling filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
            an effect if `do_resize` is set to `True`.
        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
            Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
        return_tensors (`str` or `TensorType`, *optional*):
            The type of tensors to return. Can be one of:
            - `None`: Return a list of `torch.Tensor`.
            - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
            The channel dimension format for the output image. Can be one of:
            - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format for the input image. If unset, the channel dimension format is inferred
            from the input image. Can be one of:
            - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
            - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
    """
    do_resize = do_resize if do_resize is not None else self.do_resize
    do_rescale = do_rescale if do_rescale is not None else self.do_rescale
    size_divisor = size_divisor if size_divisor is not None else self.size_divisor
    resample = resample if resample is not None else self.resample

    # Convert PIL resampling to torchvision InterpolationMode
    if is_torchvision_available():
        from ...image_utils import pil_torch_interpolation_mapping

        interpolation = (
            pil_torch_interpolation_mapping[resample]
            if isinstance(resample, (PILImageResampling, int))
            else resample
        )
    else:
        interpolation = F.InterpolationMode.BILINEAR

    # Prepare images
    images = self._prepare_image_like_inputs(
        images=images,
        do_convert_rgb=False,  # Don't force RGB conversion to support variable channels
        input_data_format=input_data_format,
    )

    return self._preprocess(
        images=images,
        do_resize=do_resize,
        size_divisor=size_divisor,
        interpolation=interpolation,
        do_rescale=do_rescale,
        rescale_factor=self.rescale_factor,
        return_tensors=return_tensors,
        **kwargs,
    )

def _preprocess(
    self,
    images: list["torch.Tensor"],
    do_resize: bool,
    size_divisor: Optional[int],
    interpolation: Optional["F.InterpolationMode"],
    do_rescale: bool,
    rescale_factor: float,
    return_tensors: Optional[Union[str, TensorType]],
    **kwargs,
) -> BatchFeature:
    """
    Preprocess the images for GLPN.

    Args:
        images (`list[torch.Tensor]`):
            List of images to preprocess.
        do_resize (`bool`):
            Whether to resize the images.
        size_divisor (`int`, *optional*):
            Size divisor for resizing. If None, uses self.size_divisor.
        interpolation (`F.InterpolationMode`, *optional*):
            Interpolation mode for resizing.
        do_rescale (`bool`):
            Whether to rescale pixel values to [0, 1].
        rescale_factor (`float`):
            Factor to rescale pixel values by.
        return_tensors (`str` or `TensorType`, *optional*):
            Type of tensors to return.

    Returns:
        `BatchFeature`: Processed images in a BatchFeature.
    """
    if size_divisor is None:
        size_divisor = self.size_divisor

    processed_images = []

    for image in images:
        # Resize if needed
        if do_resize:
            image = self.resize(image, size_divisor=size_divisor, interpolation=interpolation)

        # Rescale to [0, 1] if needed
        if do_rescale:
            image = self.rescale(image, scale=rescale_factor)

        processed_images.append(image)

    # Stack images into a batch if return_tensors is specified
    if return_tensors:
        processed_images = torch.stack(processed_images, dim=0)

    return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)


This shouldn't need to be overridden, as the preprocessing is quite simple. Parameters just need to be set accordingly (do_normalize=False etc.)
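In other words (an illustrative sketch, not the final implementation), declaring class-level defaults should let the shared fast-processor pipeline skip normalization and cropping without any preprocess override:

```python
# Illustrative class-attribute defaults; the attribute names follow the
# conventions of transformers fast image processors, but this is a sketch,
# not the merged GLPN implementation.
class GLPNFastDefaults:
    do_resize = True
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = False     # GLPN rescales to [0, 1] but applies no mean/std
    do_center_crop = False   # no fixed crop; output size is driven by size_divisor
    size_divisor = 32
```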

@yonigozlan (Member)

Closing this PR as #41725 is about to be merged

@yonigozlan closed this Nov 3, 2025


Development

Successfully merging this pull request may close these issues.

[Contributions Welcome] Add Fast Image Processors
