21 changes: 14 additions & 7 deletions src/transformers/image_transforms.py
@@ -48,7 +48,11 @@
 import jax.numpy as jnp


-def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
+def to_channel_dimension_format(
+    image: np.ndarray,
+    channel_dim: Union[ChannelDimension, str],
+    input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
+) -> np.ndarray:
     """
     Converts `image` to the channel dimension format specified by `channel_dim`.

@@ -64,9 +68,11 @@ def to_channel_dimension_format
     if not isinstance(image, np.ndarray):
         raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

-    current_channel_dim = infer_channel_dimension_format(image)
+    if input_channel_dim is None:
+        input_channel_dim = infer_channel_dimension_format(image)
+
     target_channel_dim = ChannelDimension(channel_dim)
-    if current_channel_dim == target_channel_dim:
+    if input_channel_dim == target_channel_dim:
         return image

     if target_channel_dim == ChannelDimension.FIRST:
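
The new `input_channel_dim` parameter matters because shape-based inference is ambiguous for small images. A minimal sketch of the case it guards against, assuming a transformers version that includes this change (the shapes are illustrative):

import numpy as np
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import ChannelDimension

# A channels-last image of height 3 also has 3 as its first axis, so
# infer_channel_dimension_format() can mistake it for channels-first.
image = np.zeros((3, 20, 3))  # height=3, width=20, channels=3

# Passing the known input layout skips the ambiguous inference entirely:
as_first = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=ChannelDimension.LAST)
print(as_first.shape)  # (3, 3, 20)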
@@ -152,6 +158,7 @@ def to_pil_image(
         return PIL.Image.fromarray(image)


+# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
 def get_resize_output_image_size(
     input_image: np.ndarray,
     size: Union[int, Tuple[int, int], List[int], Tuple[int]],
@@ -202,9 +209,6 @@ def get_resize_output_image_size
     short, long = (width, height) if width <= height else (height, width)
     requested_new_short = size

-    if short == requested_new_short:
-        return (height, width)
-
     new_short, new_long = requested_new_short, int(requested_new_short * long / short)

     if max_size is not None:
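
Dropping the early return matters when `max_size` is set: the shortcut handed back the original `(height, width)` before the `max_size` clamp further down could run. A worked example with illustrative numbers; the clamp shown paraphrases the linked torchvision logic, since the function's own branch sits below the visible diff:

# Short side already matches the request, but the long side exceeds max_size.
height, width = 20, 100
requested_new_short, max_size = 20, 50

short, long = min(height, width), max(height, width)  # 20, 100
new_short, new_long = requested_new_short, int(requested_new_short * long / short)  # 20, 100

# The removed shortcut would have returned (20, 100) here. Without it:
if max_size is not None and new_long > max_size:
    new_short, new_long = int(max_size * new_short / new_long), max_size

print((new_short, new_long))  # (10, 50)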
@@ -271,7 +275,10 @@ def resize
     # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
     # so we need to add it back if necessary.
     resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
-    resized_image = to_channel_dimension_format(resized_image, data_format)
+    # The image is always in channels last format after converting from a PIL image
+    resized_image = to_channel_dimension_format(
+        resized_image, data_format, input_channel_dim=ChannelDimension.LAST
+    )
     return resized_image
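
The `input_channel_dim=ChannelDimension.LAST` assertion is safe because converting a PIL image to NumPy always yields `(height, width, channels)`. A quick sketch confirming that, with illustrative shapes:

import numpy as np
import PIL.Image

# PIL sizes are (width, height); the resulting array is channels-last.
pil_image = PIL.Image.new("RGB", (20, 3))
array = np.asarray(pil_image)
print(array.shape)  # (3, 20, 3): height 3, width 20, 3 channels

# This is exactly the shape that used to trip shape-based inference after a
# resize to height 3, hence passing the known layout explicitly.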


28 changes: 23 additions & 5 deletions src/transformers/models/donut/image_processing_donut.py
@@ -210,7 +210,8 @@ def thumbnail(
         **kwargs
     ) -> np.ndarray:
         """
-        Resize the image to the specified size using thumbnail method.
+        Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
+        corresponding dimension of the specified size.

         Args:
             image (`np.ndarray`):
@@ -222,8 +223,24 @@
             data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
                 The data format of the output image. If unset, the same format as the input image is used.
         """
-        output_size = (size["height"], size["width"])
-        return resize(image, size=output_size, resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs)
+        input_height, input_width = get_image_size(image)
+        output_height, output_width = size["height"], size["width"]
+
+        # We always resize to the smallest of either the input or output size.
+        height = min(input_height, output_height)
+        width = min(input_width, output_width)
+
+        if height == input_height and width == input_width:
+            return image
+
+        if input_height > input_width:
+            width = int(input_width * height / input_height)
+        elif input_width > input_height:
+            height = int(input_height * width / input_width)
+
+        return resize(
+            image, size=(height, width), resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs
+        )

     def resize(
         self,
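
To see the new thumbnail arithmetic in action, here is a worked example with illustrative dimensions; Donut's default size is assumed to be {"height": 2560, "width": 1920}:

# Landscape page far wider than the target thumbnail box.
input_height, input_width = 1000, 3000
output_height, output_width = 2560, 1920  # assumed Donut default

height = min(input_height, output_height)  # 1000
width = min(input_width, output_width)     # 1920

# The input does not already fit, so rescale the other axis to keep the aspect ratio.
if input_width > input_height:
    height = int(input_height * width / input_width)  # int(1000 * 1920 / 3000) = 640

print((height, width))  # (640, 1920), the same 3:1 ratio as the 3000x1000 input

The `return image` branch fires when the input already fits inside the box, so images are never needlessly resized or upscaled.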
@@ -250,7 +267,8 @@ def resize
         size = get_size_dict(size)
         shortest_edge = min(size["height"], size["width"])
         output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
-        return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        resized_image = resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
+        return resized_image

     def rescale(
         self,
@@ -403,7 +421,7 @@ def preprocess
         images = [to_numpy_array(image) for image in images]

         if do_align_long_axis:
-            images = [self.align_long_axis(image) for image in images]
+            images = [self.align_long_axis(image, size=size) for image in images]

         if do_resize:
             images = [self.resize(image=image, size=size, resample=resample) for image in images]
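
For context on why `align_long_axis` now receives `size`: it needs the target orientation to decide whether to rotate the page. A minimal sketch of the idea as a hypothetical helper; the actual method in this file may differ in details:

import numpy as np

def align_long_axis_sketch(image: np.ndarray, size: dict) -> np.ndarray:
    # Assumes a channels-last array. Rotate 90 degrees when the image
    # orientation disagrees with the requested output orientation, so the
    # long axes line up.
    input_height, input_width = image.shape[0], image.shape[1]
    output_height, output_width = size["height"], size["width"]
    if (output_width < output_height and input_width > input_height) or (
        output_width > output_height and input_width < input_height
    ):
        image = np.rot90(image, 3)
    return image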
Donut integration tests (file header not captured in this view)
@@ -836,7 +836,7 @@ def test_inference_docvqa(self):
         expected_shape = torch.Size([1, 1, 57532])
         self.assertEqual(outputs.logits.shape, expected_shape)

-        expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device)
+        expected_slice = torch.tensor([24.3873, -6.4491, 32.5394]).to(torch_device)
         self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))

         # step 2: generation
@@ -872,7 +872,7 @@ def test_inference_docvqa(self):
         self.assertEqual(len(outputs.scores), 11)
         self.assertTrue(
             torch.allclose(
-                outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4
+                outputs.scores[0][0, :3], torch.tensor([5.6019, -3.5070, 13.7123], device=torch_device), atol=1e-4
             )
         )

19 changes: 19 additions & 0 deletions tests/test_image_transforms.py
@@ -184,6 +184,25 @@ def test_get_resize_output_image_size(self):
         image = np.random.randint(0, 256, (3, 50, 40))
         self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))

+        # Test the correct channel dimension is returned if the output size has height == 3
+        # Defaults to input format - channels first
+        image = np.random.randint(0, 256, (3, 18, 97))
+        resized_image = resize(image, (3, 20))
+        self.assertEqual(resized_image.shape, (3, 3, 20))
+
+        # Defaults to input format - channels last
+        image = np.random.randint(0, 256, (18, 97, 3))
+        resized_image = resize(image, (3, 20))
+        self.assertEqual(resized_image.shape, (3, 20, 3))
+
+        image = np.random.randint(0, 256, (3, 18, 97))
+        resized_image = resize(image, (3, 20), data_format="channels_last")
+        self.assertEqual(resized_image.shape, (3, 20, 3))
+
+        image = np.random.randint(0, 256, (18, 97, 3))
+        resized_image = resize(image, (3, 20), data_format="channels_first")
+        self.assertEqual(resized_image.shape, (3, 3, 20))
+
     def test_resize(self):
         image = np.random.randint(0, 256, (3, 224, 224))
