diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index d9e2e87458c4..68fba5c4af4e 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -48,7 +48,11 @@ import jax.numpy as jnp
 
 
-def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
+def to_channel_dimension_format(
+    image: np.ndarray,
+    channel_dim: Union[ChannelDimension, str],
+    input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
+) -> np.ndarray:
     """
     Converts `image` to the channel dimension format specified by `channel_dim`.
 
@@ -64,9 +68,11 @@ def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDim
     if not isinstance(image, np.ndarray):
         raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
 
-    current_channel_dim = infer_channel_dimension_format(image)
+    if input_channel_dim is None:
+        input_channel_dim = infer_channel_dimension_format(image)
+
     target_channel_dim = ChannelDimension(channel_dim)
-    if current_channel_dim == target_channel_dim:
+    if input_channel_dim == target_channel_dim:
         return image
 
     if target_channel_dim == ChannelDimension.FIRST:
@@ -152,6 +158,7 @@ def to_pil_image(
     return PIL.Image.fromarray(image)
 
 
+# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
 def get_resize_output_image_size(
     input_image: np.ndarray,
     size: Union[int, Tuple[int, int], List[int], Tuple[int]],
@@ -202,9 +209,6 @@ def get_resize_output_image_size(
     short, long = (width, height) if width <= height else (height, width)
     requested_new_short = size
 
-    if short == requested_new_short:
-        return (height, width)
-
     new_short, new_long = requested_new_short, int(requested_new_short * long / short)
 
     if max_size is not None:
@@ -271,7 +275,10 @@ def resize(
     # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
     # so we need to add it back if necessary.
     resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
-    resized_image = to_channel_dimension_format(resized_image, data_format)
+    # The image is always in channels last format after converting from a PIL image
+    resized_image = to_channel_dimension_format(
+        resized_image, data_format, input_channel_dim=ChannelDimension.LAST
+    )
 
     return resized_image
 
diff --git a/src/transformers/models/donut/image_processing_donut.py b/src/transformers/models/donut/image_processing_donut.py
index a6434bccde19..5a6b91396908 100644
--- a/src/transformers/models/donut/image_processing_donut.py
+++ b/src/transformers/models/donut/image_processing_donut.py
@@ -210,7 +210,8 @@ def thumbnail(
         **kwargs
     ) -> np.ndarray:
         """
-        Resize the image to the specified size using thumbnail method.
+        Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
+        corresponding dimension of the specified size.
 
         Args:
             image (`np.ndarray`):
@@ -222,8 +223,24 @@ def thumbnail(
             data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
                 The data format of the output image. If unset, the same format as the input image is used.
""" - output_size = (size["height"], size["width"]) - return resize(image, size=output_size, resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs) + input_height, input_width = get_image_size(image) + output_height, output_width = size["height"], size["width"] + + # We always resize to the smallest of either the input or output size. + height = min(input_height, output_height) + width = min(input_width, output_width) + + if height == input_height and width == input_width: + return image + + if input_height > input_width: + width = int(input_width * height / input_height) + elif input_width > input_height: + height = int(input_height * width / input_width) + + return resize( + image, size=(height, width), resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs + ) def resize( self, @@ -250,7 +267,8 @@ def resize( size = get_size_dict(size) shortest_edge = min(size["height"], size["width"]) output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False) - return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) + resized_image = resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) + return resized_image def rescale( self, @@ -403,7 +421,7 @@ def preprocess( images = [to_numpy_array(image) for image in images] if do_align_long_axis: - images = [self.align_long_axis(image) for image in images] + images = [self.align_long_axis(image, size=size) for image in images] if do_resize: images = [self.resize(image=image, size=size, resample=resample) for image in images] diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index 279614371bf8..6228cb51fd5a 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -836,7 +836,7 @@ def test_inference_docvqa(self): expected_shape = torch.Size([1, 1, 57532]) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device) + expected_slice = torch.tensor([24.3873, -6.4491, 32.5394]).to(torch_device) self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4)) # step 2: generation @@ -872,7 +872,7 @@ def test_inference_docvqa(self): self.assertEqual(len(outputs.scores), 11) self.assertTrue( torch.allclose( - outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4 + outputs.scores[0][0, :3], torch.tensor([5.6019, -3.5070, 13.7123], device=torch_device), atol=1e-4 ) ) diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 89a6b135bf0c..f22657f63752 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -184,6 +184,25 @@ def test_get_resize_output_image_size(self): image = np.random.randint(0, 256, (3, 50, 40)) self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17)) + # Test correct channel dimension is returned if output size if height == 3 + # Defaults to input format - channels first + image = np.random.randint(0, 256, (3, 18, 97)) + resized_image = resize(image, (3, 20)) + self.assertEqual(resized_image.shape, (3, 3, 20)) + + # Defaults to input format - channels last + image = np.random.randint(0, 256, (18, 97, 3)) + resized_image = resize(image, (3, 20)) + 
+        self.assertEqual(resized_image.shape, (3, 20, 3))
+
+        image = np.random.randint(0, 256, (3, 18, 97))
+        resized_image = resize(image, (3, 20), data_format="channels_last")
+        self.assertEqual(resized_image.shape, (3, 20, 3))
+
+        image = np.random.randint(0, 256, (18, 97, 3))
+        resized_image = resize(image, (3, 20), data_format="channels_first")
+        self.assertEqual(resized_image.shape, (3, 3, 20))
+
     def test_resize(self):
         image = np.random.randint(0, 256, (3, 224, 224))
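
For context, a minimal sketch (not part of the diff; it assumes a `transformers` install that includes the changes above and the import paths of the files touched here) of why the new `input_channel_dim` argument matters. A 3-pixel-tall, channels-last image has a shape like (3, 20, 3), which format inference can read as channels first; passing `input_channel_dim` skips the guess, which is what `resize` now does after the PIL round trip.

```python
import numpy as np

from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import ChannelDimension

# Channels-last input; resizing to height 3 yields a (3, 20, 3) array whose
# layout is ambiguous from the shape alone.
image = np.random.randint(0, 256, (18, 97, 3), dtype=np.uint8)
resized = resize(image, (3, 20))  # stays channels last: (3, 20, 3)

# Without the hint, to_channel_dimension_format would re-infer the layout of the
# (3, 20, 3) result and could treat it as channels first. The explicit hint makes
# the conversion unambiguous.
converted = to_channel_dimension_format(
    resized, ChannelDimension.FIRST, input_channel_dim=ChannelDimension.LAST
)
print(resized.shape)    # (3, 20, 3)
print(converted.shape)  # (3, 3, 20)
```

The new `test_get_resize_output_image_size` cases above exercise exactly this situation for `resize` with and without an explicit `data_format`.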