diff --git a/docs/source/en/internal/image_processing_utils.mdx b/docs/source/en/internal/image_processing_utils.mdx
index f1658e55525d..831458bedab1 100644
--- a/docs/source/en/internal/image_processing_utils.mdx
+++ b/docs/source/en/internal/image_processing_utils.mdx
@@ -29,6 +29,8 @@ Most of those are only useful if you are studying the code of the image processo
 
 [[autodoc]] image_transforms.normalize
 
+[[autodoc]] image_transforms.pad
+
 [[autodoc]] image_transforms.rgb_to_id
 
 [[autodoc]] image_transforms.rescale
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index d8d1d60935d7..1909d04e2a67 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -14,11 +14,11 @@
 # limitations under the License.
 
 import warnings
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 
-from transformers.utils import TensorType
+from transformers.utils import ExplicitEnum, TensorType
 from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
 
 
@@ -38,13 +38,14 @@
     )
 
 
-if TYPE_CHECKING:
-    if is_torch_available():
-        import torch
-    if is_tf_available():
-        import tensorflow as tf
-    if is_flax_available():
-        import jax.numpy as jnp
+if is_torch_available():
+    import torch
+
+if is_tf_available():
+    import tensorflow as tf
+
+if is_flax_available():
+    import jax.numpy as jnp
 
 
 def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
@@ -568,3 +569,100 @@ def id_to_rgb(id_map):
         color.append(id_map % 256)
         id_map //= 256
     return color
+
+
+class PaddingMode(ExplicitEnum):
+    """
+    Enum class for the different padding modes to use when padding images.
+    """
+
+    CONSTANT = "constant"
+    REFLECT = "reflect"
+    REPLICATE = "replicate"
+    SYMMETRIC = "symmetric"
+
+
+def pad(
+    image: np.ndarray,
+    padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
+    mode: PaddingMode = PaddingMode.CONSTANT,
+    constant_values: Union[float, Iterable[float]] = 0.0,
+    data_format: Optional[Union[str, ChannelDimension]] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+    """
+    Pads the `image` with the specified (height, width) `padding` and `mode`.
+
+    Args:
+        image (`np.ndarray`):
+            The image to pad.
+        padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
+            Padding to apply to the edges of the height, width axes. Can be one of three formats:
+            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
+            - `((before, after),)` yields same before and after pad for height and width.
+            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
+        mode (`PaddingMode`):
+            The padding mode to use. Can be one of:
+                - `"constant"`: pads with a constant value.
+                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
+                  vector along each axis.
+                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
+                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
+        constant_values (`float` or `Iterable[float]`, *optional*):
+            The value to use for the padding if `mode` is `"constant"`.
+        data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the output image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            If unset, will use same as the input image.
+        input_data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            If unset, will use the inferred format of the input image.
+
+    Returns:
+        `np.ndarray`: The padded image.
+
+    """
+    if input_data_format is None:
+        input_data_format = infer_channel_dimension_format(image)
+
+    def _expand_for_data_format(values):
+        """
+        Convert values to be in the format expected by np.pad based on the data format.
+        """
+        if isinstance(values, (int, float)):
+            values = ((values, values), (values, values))
+        elif isinstance(values, tuple) and len(values) == 1:
+            values = ((values[0], values[0]), (values[0], values[0]))
+        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
+            values = (values, values)
+        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
+            values = values
+        else:
+            raise ValueError(f"Unsupported format: {values}")
+
+        # add 0 for channel dimension
+        values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))
+
+        # Leave a leading batch axis unpadded; np.pad needs a (before, after) pair for every axis
+        values = ((0, 0), *values) if image.ndim == 4 else values
+        return values
+
+    padding = _expand_for_data_format(padding)
+
+    if mode == PaddingMode.CONSTANT:
+        constant_values = _expand_for_data_format(constant_values)
+        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
+    elif mode == PaddingMode.REFLECT:
+        image = np.pad(image, padding, mode="reflect")
+    elif mode == PaddingMode.REPLICATE:
+        image = np.pad(image, padding, mode="edge")
+    elif mode == PaddingMode.SYMMETRIC:
+        image = np.pad(image, padding, mode="symmetric")
+    else:
+        raise ValueError(f"Invalid padding mode: {mode}")
+
+    image = to_channel_dimension_format(image, data_format) if data_format is not None else image
+    return image
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
index d0b7c9ade137..618181b004d5 100644
--- a/tests/test_image_transforms.py
+++ b/tests/test_image_transforms.py
@@ -41,6 +41,7 @@
         get_resize_output_image_size,
         id_to_rgb,
         normalize,
+        pad,
         resize,
         rgb_to_id,
         to_channel_dimension_format,
@@ -289,3 +290,127 @@ def test_id_to_rgb(self):
             ]
         )
         self.assertTrue(np.allclose(id_to_rgb(id_array), color))
+
+    def test_pad(self):
+        # fmt: off
+        image = np.array([[
+            [0, 1],
+            [2, 3],
+        ]])
+        # fmt: on
+
+        # Test that exception is raised if unknown padding mode is specified
+        with self.assertRaises(ValueError):
+            pad(image, 10, mode="unknown")
+
+        # Test that exception is raised if invalid padding is specified
+        with self.assertRaises(ValueError):
+            # Cannot pad on channel dimension
+            pad(image, (5, 10, 10))
+
+        # Test image is padded equally on all sides if padding is an int
+        # fmt: off
+        expected_image = np.array([
+            [[0, 0, 0, 0],
+             [0, 0, 1, 0],
+             [0, 2, 3, 0],
+             [0, 0, 0, 0]],
+        ])
+        # fmt: on
+        self.assertTrue(np.allclose(expected_image, pad(image, 1)))
+
+        # Test the left and right of each axis is padded (pad_left, pad_right)
+        # fmt: off
+        expected_image = np.array(
+            [[0, 0, 0, 0, 0],
+             [0, 0, 0, 0, 0],
+             [0, 0, 0, 1, 0],
+             [0, 0, 2, 3, 0],
+             [0, 0, 0, 0, 0]])
+        # fmt: on
+        self.assertTrue(np.allclose(expected_image, pad(image, (2, 1))))
+
+        # Test only one axis is padded (pad_left, pad_right)
+        # fmt: off
+        expected_image = np.array([[
+            [9, 9],
+            [9, 9],
+            [0, 1],
+            [2, 3],
+            [9, 9]
+        ]])
+        # fmt: on
+        self.assertTrue(np.allclose(expected_image, pad(image, ((2, 1), (0, 0)), constant_values=9)))
+
+        # Test padding with a constant value
+        # fmt: off
+        expected_image = np.array([[
+            [8, 8, 0, 1, 9],
+            [8, 8, 2, 3, 9],
+            [8, 8, 7, 7, 9],
+            [8, 8, 7, 7, 9]
+        ]])
+        # fmt: on
+        self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), constant_values=((6, 7), (8, 9)))))
+
+        # fmt: off
+        image = np.array([[
+            [0, 1, 2],
+            [3, 4, 5],
+            [6, 7, 8],
+        ]])
+        # fmt: on
+
+        # Test padding with PaddingMode.REFLECT
+        # fmt: off
+        expected_image = np.array([[
+            [2, 1, 0, 1, 2, 1],
+            [5, 4, 3, 4, 5, 4],
+            [8, 7, 6, 7, 8, 7],
+            [5, 4, 3, 4, 5, 4],
+            [2, 1, 0, 1, 2, 1],
+        ]])
+        # fmt: on
+        self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect")))
+
+        # Test padding with PaddingMode.REPLICATE
+        # fmt: off
+        expected_image = np.array([[
+            [0, 0, 0, 1, 2, 2],
+            [3, 3, 3, 4, 5, 5],
+            [6, 6, 6, 7, 8, 8],
+            [6, 6, 6, 7, 8, 8],
+            [6, 6, 6, 7, 8, 8],
+        ]])
+        # fmt: on
+        self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="replicate")))
+
+        # Test padding with PaddingMode.SYMMETRIC
+        # fmt: off
+        expected_image = np.array([[
+            [1, 0, 0, 1, 2, 2],
+            [4, 3, 3, 4, 5, 5],
+            [7, 6, 6, 7, 8, 8],
+            [7, 6, 6, 7, 8, 8],
+            [4, 3, 3, 4, 5, 5],
+        ]])
+        # fmt: on
+        self.assertTrue(np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="symmetric")))
+
+        # Test we can specify the output data format
+        # Test padding with PaddingMode.REFLECT
+        # fmt: off
+        image = np.array([[
+            [0, 1],
+            [2, 3],
+        ]])
+        expected_image = np.array([
+            [[0], [1], [0], [1], [0]],
+            [[2], [3], [2], [3], [2]],
+            [[0], [1], [0], [1], [0]],
+            [[2], [3], [2], [3], [2]]
+        ])
+        # fmt: on
+        self.assertTrue(
+            np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last"))
+        )