diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index 63585d5e038b..71afaaf268da 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -110,10 +110,11 @@ def rescale(
     if not isinstance(image, np.ndarray):
         raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
 
+    image = image.astype(dtype)
+
     rescaled_image = image * scale
     if data_format is not None:
         rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
-    rescaled_image = rescaled_image.astype(dtype)
 
     return rescaled_image
 
diff --git a/src/transformers/models/efficientnet/image_processing_efficientnet.py b/src/transformers/models/efficientnet/image_processing_efficientnet.py
index eaefb9c10150..8873a8006997 100644
--- a/src/transformers/models/efficientnet/image_processing_efficientnet.py
+++ b/src/transformers/models/efficientnet/image_processing_efficientnet.py
@@ -153,7 +153,13 @@ def rescale(
         **kwargs,
     ):
         """
-        Rescale an image by a scale factor. image = image * scale.
+        Rescale an image by a scale factor.
+
+        If offset is True, the image is rescaled between [-1, 1].
+            image = image * scale * 2 - 1
+
+        If offset is False, the image is rescaled between [0, 1].
+            image = image * scale
 
         Args:
             image (`np.ndarray`):
@@ -165,13 +171,12 @@ def rescale(
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
         """
+        scale = scale * 2 if offset else scale
+        rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+
         if offset:
-            rescaled_image = (image - 127.5) * scale
-            if data_format is not None:
-                rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
-            rescaled_image = rescaled_image.astype(np.float32)
-        else:
-            rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+            rescaled_image = rescaled_image - 1
+
         return rescaled_image
 
     def preprocess(
diff --git a/src/transformers/models/vivit/image_processing_vivit.py b/src/transformers/models/vivit/image_processing_vivit.py
index 2aa7a911fed1..41666e99999f 100644
--- a/src/transformers/models/vivit/image_processing_vivit.py
+++ b/src/transformers/models/vivit/image_processing_vivit.py
@@ -167,6 +167,7 @@ def resize(
             raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
         return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
 
+    # Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
     def rescale(
         self,
         image: np.ndarray,
@@ -178,23 +179,29 @@ def rescale(
         """
         Rescale an image by a scale factor.
 
-        If offset is `True`, image scaled between [-1, 1]: image = (image - 127.5) * scale. If offset is `False`, image
-        scaled between [0, 1]: image = image * scale
+        If offset is True, the image is rescaled between [-1, 1].
+            image = image * scale * 2 - 1
+
+        If offset is False, the image is rescaled between [0, 1].
+            image = image * scale
 
         Args:
             image (`np.ndarray`):
                 Image to rescale.
             scale (`int` or `float`):
                 Scale to apply to the image.
-            offset (`bool`, *optional*):
+            offset (`bool`, *optional*):
                 Whether to scale the image in both negative and positive directions.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
         """
-        image = image.astype(np.float32)
+        scale = scale * 2 if offset else scale
+        rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
+
         if offset:
-            image = image - (scale / 2)
-        return rescale(image, scale=scale, data_format=data_format, **kwargs)
+            rescaled_image = rescaled_image - 1
+
+        return rescaled_image
 
     def _preprocess_image(
         self,
diff --git a/tests/models/efficientnet/test_image_processing_efficientnet.py b/tests/models/efficientnet/test_image_processing_efficientnet.py
index 8e4fad3b0840..11aee2d01c5c 100644
--- a/tests/models/efficientnet/test_image_processing_efficientnet.py
+++ b/tests/models/efficientnet/test_image_processing_efficientnet.py
@@ -193,3 +193,17 @@ def test_call_pytorch(self):
                 self.image_processor_tester.size["width"],
             ),
         )
+
+    def test_rescale(self):
+        # EfficientNet optionally rescales between -1 and 1 instead of the usual 0 and 1
+        image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
+
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255)
+        expected_image = image.astype(np.float32) * (2 / 255.0) - 1
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
+        expected_image = image.astype(np.float32) / 255.0
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
diff --git a/tests/models/vivit/test_image_processing_vivit.py b/tests/models/vivit/test_image_processing_vivit.py
index f33553db0f25..75a3e0264c0a 100644
--- a/tests/models/vivit/test_image_processing_vivit.py
+++ b/tests/models/vivit/test_image_processing_vivit.py
@@ -212,3 +212,17 @@ def test_call_pytorch(self):
                 self.image_processor_tester.crop_size["width"],
             ),
         )
+
+    def test_rescale(self):
+        # ViVit optionally rescales between -1 and 1 instead of the usual 0 and 1
+        image = np.arange(0, 256, 1, dtype=np.uint8).reshape(1, 8, 32)
+
+        image_processor = self.image_processing_class(**self.image_processor_dict)
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255)
+        expected_image = image.astype(np.float32) * (2 / 255.0) - 1
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
+
+        rescaled_image = image_processor.rescale(image, scale=1 / 255, offset=False)
+        expected_image = image.astype(np.float32) / 255.0
+        self.assertTrue(np.allclose(rescaled_image, expected_image))
diff --git a/tests/models/vivit/test_modeling_vivit.py b/tests/models/vivit/test_modeling_vivit.py
index 43db8bad7bde..ed032e4bdd32 100644
--- a/tests/models/vivit/test_modeling_vivit.py
+++ b/tests/models/vivit/test_modeling_vivit.py
@@ -345,6 +345,6 @@ def test_inference_for_video_classification(self):
         self.assertEqual(outputs.logits.shape, expected_shape)
 
         # taken from original model
-        expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658]).to(torch_device)
+        expected_slice = torch.tensor([-0.9498, 2.7971, -1.4049, 0.1024, -1.8353]).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))
diff --git a/tests/pipelines/test_pipelines_zero_shot_image_classification.py b/tests/pipelines/test_pipelines_zero_shot_image_classification.py
index fbbfc78cae3c..197019f42e7b 100644
--- a/tests/pipelines/test_pipelines_zero_shot_image_classification.py
+++ b/tests/pipelines/test_pipelines_zero_shot_image_classification.py
@@ -85,6 +85,7 @@ def test_small_model_pt(self):
             [
                 [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "b"}, {"score": 0.333, "label": "c"}],
                 [{"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}, {"score": 0.333, "label": "b"}],
+                [{"score": 0.333, "label": "b"}, {"score": 0.333, "label": "a"}, {"score": 0.333, "label": "c"}],
             ],
         )