diff --git a/docs/source/en/model_doc/glpn.md b/docs/source/en/model_doc/glpn.md index 8eb2c338a456..8081a6e0c66f 100644 --- a/docs/source/en/model_doc/glpn.md +++ b/docs/source/en/model_doc/glpn.md @@ -61,6 +61,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] GLPNImageProcessor - preprocess +## GLPNImageProcessorFast + +[[autodoc]] GLPNImageProcessorFast + - preprocess + ## GLPNModel [[autodoc]] GLPNModel diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index a145754d3209..0e1e85273cd5 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -305,6 +305,8 @@ def resize( Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. + antialias (`bool`, *optional*, defaults to `True`): + Whether to use antialiasing. Returns: `torch.Tensor`: The resized image. diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 60af0f869bad..653f3b43a2c2 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -103,7 +103,7 @@ ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")), - ("glpn", ("GLPNImageProcessor", None)), + ("glpn", ("GLPNImageProcessor", "GLPNImageProcessorFast")), ("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")), ("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")), ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")), diff --git a/src/transformers/models/glpn/__init__.py b/src/transformers/models/glpn/__init__.py index 2a5b38675c34..8d81194031c7 100644 --- a/src/transformers/models/glpn/__init__.py +++ b/src/transformers/models/glpn/__init__.py @@ -21,6 +21,7 @@ from .configuration_glpn import * from .feature_extraction_glpn import * from .image_processing_glpn import * + from .image_processing_glpn_fast import * from .modeling_glpn import * else: import sys diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py index 35306eabc8d5..a50940840034 100644 --- a/src/transformers/models/glpn/image_processing_glpn.py +++ b/src/transformers/models/glpn/image_processing_glpn.py @@ -39,6 +39,7 @@ valid_images, validate_preprocess_arguments, ) +from ...processing_utils import ImagesKwargs from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends @@ -49,6 +50,17 @@ logger = logging.get_logger(__name__) +class GLPNImageProcessorKwargs(ImagesKwargs, total=False): + """ + size_divisor (`int`, *optional*, defaults to 32): + When `do_resize` is `True`, images are resized so their height and width are rounded down to the closest + multiple of `size_divisor`. 
+ """ + + size_divisor: int + resample: PILImageResampling + + @requires(backends=("vision",)) class GLPNImageProcessor(BaseImageProcessor): r""" @@ -66,9 +78,12 @@ class GLPNImageProcessor(BaseImageProcessor): do_rescale (`bool`, *optional*, defaults to `True`): Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Can be overridden by `do_rescale` in `preprocess`. + rescale_factor (`float`, *optional*, defaults to `1 / 255`): + The scaling factor to apply to the pixel values. Can be overridden by `rescale_factor` in `preprocess`. """ model_input_names = ["pixel_values"] + valid_kwargs = GLPNImageProcessorKwargs def __init__( self, @@ -76,12 +91,14 @@ def __init__( size_divisor: int = 32, resample=PILImageResampling.BILINEAR, do_rescale: bool = True, + rescale_factor: Optional[float] = 1 / 255, **kwargs, ) -> None: self.do_resize = do_resize self.do_rescale = do_rescale self.size_divisor = size_divisor self.resample = resample + self.rescale_factor = rescale_factor super().__init__(**kwargs) def resize( @@ -142,6 +159,7 @@ def preprocess( size_divisor: Optional[int] = None, resample=None, do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, return_tensors: Optional[Union[TensorType, str]] = None, data_format: ChannelDimension = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, @@ -181,6 +199,7 @@ def preprocess( """ do_resize = do_resize if do_resize is not None else self.do_resize do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor size_divisor = size_divisor if size_divisor is not None else self.size_divisor resample = resample if resample is not None else self.resample @@ -217,7 +236,9 @@ def preprocess( ] if do_rescale: - images = [self.rescale(image, scale=1 / 255, input_data_format=input_data_format) for image in images] + images = [ + self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) for image in images + ] images = [ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images diff --git a/src/transformers/models/glpn/image_processing_glpn_fast.py b/src/transformers/models/glpn/image_processing_glpn_fast.py new file mode 100644 index 000000000000..a906dc29c271 --- /dev/null +++ b/src/transformers/models/glpn/image_processing_glpn_fast.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fast Image processor class for GLPN.""" + +from typing import Optional, Union + +import torch +from torchvision.transforms.v2 import functional as F + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images +from ...image_utils import ( + PILImageResampling, + SizeDict, +) +from ...utils import ( + TensorType, + auto_docstring, + requires_backends, +) +from .image_processing_glpn import GLPNImageProcessorKwargs + + +@auto_docstring +class GLPNImageProcessorFast(BaseImageProcessorFast): + do_resize = True + do_rescale = True + rescale_factor = 1 / 255 + resample = PILImageResampling.BILINEAR + size_divisor = 32 + valid_kwargs = GLPNImageProcessorKwargs + + def _validate_preprocess_kwargs(self, **kwargs): + # pop `do_resize` to not raise an error as `size` is not None + kwargs.pop("do_resize", None) + return super()._validate_preprocess_kwargs(**kwargs) + + def resize( + self, + image: "torch.Tensor", + size_divisor: int, + interpolation: Optional["F.InterpolationMode"] = None, + antialias: bool = True, + **kwargs, + ) -> "torch.Tensor": + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`torch.Tensor`): + Image to resize. + size (`SizeDict`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`): + `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`. + antialias (`bool`, *optional*, defaults to `True`): + Whether to use antialiasing. + + Returns: + `torch.Tensor`: The resized image. + """ + height, width = image.shape[-2:] + # Rounds the height and width down to the closest multiple of size_divisor + new_h = height // size_divisor * size_divisor + new_w = width // size_divisor * size_divisor + return super().resize( + image, SizeDict(height=new_h, width=new_w), interpolation=interpolation, antialias=antialias + ) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size_divisor: Optional[int] = None, + interpolation: Optional["F.InterpolationMode"] = None, + do_rescale: bool = True, + rescale_factor: Optional[float] = 1 / 255, + do_normalize: bool = False, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + disable_grouping: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + resample: Optional[PILImageResampling] = None, + **kwargs, + ) -> BatchFeature: + grouped_images, grouped_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_groups = {} + + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(stacked_images, size_divisor=size_divisor, interpolation=interpolation) + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_groups[shape] = stacked_images + + processed_images = reorder_images(processed_groups, grouped_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + def post_process_depth_estimation(self, outputs, target_sizes=None): + """ + Convert raw model outputs to final depth predictions. 
+ Mirrors slow GLPN: PyTorch interpolate w/ bicubic, align_corners=False. + """ + requires_backends(self, "torch") + predicted_depth = outputs.predicted_depth + + results = [] + target_sizes = target_sizes or [None] * predicted_depth.shape[0] + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + # Add batch and channel dimensions for interpolation + depth_4d = depth[None, None, ...] + resized = torch.nn.functional.interpolate( + depth_4d, size=target_size, mode="bicubic", align_corners=False + ) + depth = resized.squeeze(0).squeeze(0) + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["GLPNImageProcessorFast"] diff --git a/tests/models/glpn/test_image_processing_glpn.py b/tests/models/glpn/test_image_processing_glpn.py index 7f6a960755e7..396f7e9543e7 100644 --- a/tests/models/glpn/test_image_processing_glpn.py +++ b/tests/models/glpn/test_image_processing_glpn.py @@ -18,7 +18,7 @@ import numpy as np from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs @@ -31,6 +31,9 @@ from transformers import GLPNImageProcessor + if is_torchvision_available(): + from transformers import GLPNImageProcessorFast + class GLPNImageProcessingTester: def __init__( @@ -87,19 +90,32 @@ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=F torchify=torchify, ) + def prepare_depth_outputs(self): + if not is_torch_available(): + return None + depth_tensors = prepare_image_inputs( + batch_size=self.batch_size, + num_channels=1, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=True, + torchify=True, + ) + depth_tensors = [depth_tensor.squeeze(0) for depth_tensor in depth_tensors] + stacked_depth_tensors = torch.stack(depth_tensors, dim=0) + return type("DepthOutput", (), {"predicted_depth": stacked_depth_tensors}) + @require_torch @require_vision class GLPNImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = GLPNImageProcessor if is_vision_available() else None + fast_image_processing_class = GLPNImageProcessorFast if is_torchvision_available() else None def setUp(self): super().setUp() self.image_processor_tester = GLPNImageProcessingTester(self) - - @property - def image_processor_dict(self): - return self.image_processor_tester.prepare_image_processor_dict() + self.image_processor_dict = self.image_processor_tester.prepare_image_processor_dict() def test_image_processor_properties(self): image_processing = self.image_processing_class(**self.image_processor_dict) @@ -115,7 +131,6 @@ def test_call_pil(self): image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) for image in image_inputs: self.assertIsInstance(image, Image.Image) - # Test not batched input (GLPNImageProcessor doesn't support batching) encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) @@ -161,3 +176,43 @@ def test_call_numpy_4_channels(self): expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) self.assertTrue(tuple(encoded_images.shape) == (1, *expected_output_image_shape)) 
self.image_processing_class.num_channels = 3 + + # override as glpn image processors don't support heterogeneous batching + @require_vision + @require_torch + def test_slow_fast_equivalence_batched(self): + if not self.test_slow_image_processor or not self.test_fast_image_processor: + self.skipTest(reason="Skipping slow/fast equivalence test") + + if self.image_processing_class is None or self.fast_image_processing_class is None: + self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined") + + dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + image_processor_slow = self.image_processing_class(**self.image_processor_dict) + image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict) + + encoding_slow = image_processor_slow(dummy_images, return_tensors="pt") + encoding_fast = image_processor_fast(dummy_images, return_tensors="pt") + + self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values) + + def test_post_process_depth_equivalence(self): + # Check that both processors produce equivalent post-processed depth maps + if self.fast_image_processing_class is None: + self.skipTest("TorchVision not available") + + outputs = self.image_processor_tester.prepare_depth_outputs() + slow = self.image_processing_class(**self.image_processor_dict) + fast = self.fast_image_processing_class(**self.image_processor_dict) + + # target_sizes simulate resized inference outputs + target_sizes = [(240, 320)] * self.image_processor_tester.batch_size + processed_slow = slow.post_process_depth_estimation(outputs, target_sizes=target_sizes) + processed_fast = fast.post_process_depth_estimation(outputs, target_sizes=target_sizes) + + # Compare per-sample predicted depth tensors + for pred_slow, pred_fast in zip(processed_slow, processed_fast): + depth_slow = pred_slow["predicted_depth"] + depth_fast = pred_fast["predicted_depth"] + torch.testing.assert_close(depth_fast, depth_slow, atol=1e-1, rtol=1e-3) + self.assertLessEqual(torch.mean(torch.abs(depth_fast.float() - depth_slow.float())).item(), 5e-3)
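
A minimal sketch of the behaviour the fast path keeps from the slow processor: heights and widths are floored to the nearest multiple of `size_divisor`, so no `size` argument is needed (assumes this branch plus torchvision are installed; the 500x700 random tensor is only for illustration):

```python
# Sketch only: GLPNImageProcessorFast floors height/width to the nearest multiple of
# size_divisor (32 by default), so a 500x700 input comes out as 480x672.
import torch

from transformers import GLPNImageProcessorFast

processor = GLPNImageProcessorFast()
image = torch.rand(3, 500, 700)  # channels-first, dimensions not divisible by 32, random values

pixel_values = processor(image, return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 480, 672]): 500 // 32 * 32 = 480, 700 // 32 * 32 = 672
```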
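And an end-to-end sketch of how the new class is reached through the updated auto mapping and how `post_process_depth_estimation` restores the original resolution (assumes the public `vinvino02/glpn-kitti` checkpoint and network access for the sample image; any GLPN checkpoint should behave the same way):

```python
# End-to-end sketch, not part of the diff: swap in any GLPN checkpoint as needed.
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, GLPNForDepthEstimation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# use_fast=True resolves to GLPNImageProcessorFast via the updated auto mapping above
processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti", use_fast=True)
model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Upsample the prediction back to the original (height, width) with bicubic interpolation
result = processor.post_process_depth_estimation(outputs, target_sizes=[image.size[::-1]])[0]
predicted_depth = result["predicted_depth"]  # torch.Tensor of shape (height, width)
```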