Merged
Changes from 4 commits
3 changes: 2 additions & 1 deletion src/transformers/image_transforms.py
@@ -110,10 +110,11 @@ def rescale(
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

image = image.astype(np.float32)
Comment from the PR author (Contributor):
Image should be cast before scaling to ensure precision isn't lost when rescaling with a float value.

This has the potential to slightly change the resulting pixel values for models. To check that this doesn't affect our current models' outputs, I ran integration tests for swin, vit, detr, clip and resnet.
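As a rough illustration of the precision point (my reading of the comment, not part of the PR): in NumPy, multiplying a uint8 image by a plain Python float promotes the intermediate result to float64, while casting to float32 first keeps the whole computation in float32, so the scaling and the returned dtype agree. A minimal sketch, assuming the usual 1/255 rescale factor:

    import numpy as np

    image = np.arange(0, 256, dtype=np.uint8).reshape(1, 8, 32)
    scale = 1 / 255

    # Without an explicit cast, NumPy promotes uint8 * float to float64.
    print((image * scale).dtype)                     # float64
    # Casting first keeps the computation in float32 end to end,
    # which is the order of operations rescale now uses.
    print((image.astype(np.float32) * scale).dtype)  # float32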


rescaled_image = image * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image


src/transformers/models/efficientnet/image_processing_efficientnet.py
@@ -153,7 +153,13 @@ def rescale(
**kwargs,
):
"""
Rescale an image by a scale factor. image = image * scale.
Rescale an image by a scale factor.

If offset is True, the image is rescaled between [-1, 1].
image = image * scale * 2 - 1

If offset is False, the image is rescaled between [0, 1].
image = image * scale

Args:
image (`np.ndarray`):
@@ -165,13 +171,12 @@
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
scale = scale * 2 if offset else scale
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)

if offset:
rescaled_image = (image - 127.5) * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(np.float32)
else:
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
rescaled_image = rescaled_image - 1

return rescaled_image

def preprocess(
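A quick worked check of the offset arithmetic in the updated docstring above (not part of the diff; scale = 1/255 is assumed as the usual rescale factor): the removed branch computed (image - 127.5) * scale, which maps 0..255 to roughly [-0.5, 0.5], whereas the new image * scale * 2 - 1 maps 0..255 onto exactly [-1, 1], the range the docstring and the new tests describe.

    import numpy as np

    image = np.array([0.0, 127.5, 255.0], dtype=np.float32)
    scale = 1 / 255

    old = (image - 127.5) * scale   # previous offset branch
    new = image * scale * 2 - 1     # refactored formula

    print(old)  # ≈ [-0.5  0.   0.5] -> not the documented [-1, 1] range
    print(new)  # ≈ [-1.  0.  1.]    -> matches the documented [-1, 1] range

Under that reading, the shift in effective pixel range would also be consistent with the updated expected_slice in the ViViT integration test at the end of this diff.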
19 changes: 13 additions & 6 deletions src/transformers/models/vivit/image_processing_vivit.py
@@ -167,6 +167,7 @@ def resize(
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)

# Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
def rescale(
self,
image: np.ndarray,
@@ -178,23 +179,29 @@ def rescale(
"""
Rescale an image by a scale factor.

If offset is `True`, image scaled between [-1, 1]: image = (image - 127.5) * scale. If offset is `False`, image
scaled between [0, 1]: image = image * scale
If offset is True, the image is rescaled between [-1, 1].
image = image * scale * 2 - 1

If offset is False, the image is rescaled between [0, 1].
image = image * scale

Args:
image (`np.ndarray`):
Image to rescale.
scale (`int` or `float`):
Scale to apply to the image.
offset (`bool`, *optional*):
offset (`bool`, *optional*):
Whether to scale the image in both negative and positive directions.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
image = image.astype(np.float32)
scale = scale * 2 if offset else scale
rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)

if offset:
image = image - (scale / 2)
return rescale(image, scale=scale, data_format=data_format, **kwargs)
rescaled_image = rescaled_image - 1

return rescaled_image

def _preprocess_image(
self,
14 changes: 14 additions & 0 deletions tests/models/efficientnet/test_image_processing_efficientnet.py
@@ -193,3 +193,17 @@ def test_call_pytorch(self):
self.image_processor_tester.size["width"],
),
)

def test_rescale(self):
# EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1
image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)

image_processor = self.image_processing_class(**self.image_processor_dict)

rescaled_image = image_processor.rescale(image, scale=1 / 255)
expected_image = image.astype(np.float32) * (2 / 255.0) - 1
self.assertTrue(np.allclose(rescaled_image, expected_image))

rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
expected_image = image.astype(np.float32) / 255.0
self.assertTrue(np.allclose(rescaled_image, expected_image))
14 changes: 14 additions & 0 deletions tests/models/vivit/test_image_processing_vivit.py
@@ -212,3 +212,17 @@ def test_call_pytorch(self):
self.image_processor_tester.crop_size["width"],
),
)

def test_rescale(self):
# ViVit optionally rescales between -1 and 1 instead of the usual 0 and 1
image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)

image_processor = self.image_processing_class(**self.image_processor_dict)

rescaled_image = image_processor.rescale(image, scale=1 / 255)
expected_image = image.astype(np.float32) * (2 / 255.0) - 1
self.assertTrue(np.allclose(rescaled_image, expected_image))

rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
expected_image = image.astype(np.float32) / 255.0
self.assertTrue(np.allclose(rescaled_image, expected_image))
2 changes: 1 addition & 1 deletion tests/models/vivit/test_modeling_vivit.py
@@ -345,6 +345,6 @@ def test_inference_for_video_classification(self):
self.assertEqual(outputs.logits.shape, expected_shape)

# taken from original model
expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]).to(torch_device)
expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))