From db5968355852cef60a96503328535cf8004928bb Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 21 Sep 2022 17:19:30 +0100 Subject: [PATCH] Revert "Type cast before normalize vit" --- .../models/vit/feature_extraction_vit.py | 23 ++++-------- .../models/vit/test_feature_extraction_vit.py | 35 ------------------- 2 files changed, 7 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/vit/feature_extraction_vit.py b/src/transformers/models/vit/feature_extraction_vit.py index 0a351824564c..29c0fa3fc4f6 100644 --- a/src/transformers/models/vit/feature_extraction_vit.py +++ b/src/transformers/models/vit/feature_extraction_vit.py @@ -98,15 +98,13 @@ def __call__( tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`): - If set, will return a tensor of a particular framework. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: - Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` object. - - `'pt'`: Return PyTorch `torch.Tensor` object. - - `'np'`: Return NumPy `np.ndarray` object. - - `'jax'`: Return JAX `jnp.ndarray` object. - - None: Return list of `np.ndarray` objects. + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -141,15 +139,8 @@ def __call__( # transformations (resizing + normalization) if self.do_resize and self.size is not None: images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] - - # if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here - make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3) - images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images] - if self.do_normalize: - images = [ - self.normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True) for image in images - ] + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] # return as BatchFeature data = {"pixel_values": images} diff --git a/tests/models/vit/test_feature_extraction_vit.py b/tests/models/vit/test_feature_extraction_vit.py index a957965c0894..2daf6452fff5 100644 --- a/tests/models/vit/test_feature_extraction_vit.py +++ b/tests/models/vit/test_feature_extraction_vit.py @@ -18,7 +18,6 @@ import numpy as np -from parameterized import parameterized from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_torch_available, is_vision_available @@ -190,37 +189,3 @@ def test_call_pytorch(self): self.feature_extract_tester.size, ), ) - - @parameterized.expand( - [ - ("do_resize_True_do_normalize_True", True, True), - ("do_resize_True_do_normalize_False", True, False), - ("do_resize_False_do_normalize_True", False, True), - ("do_resize_False_do_normalize_False", False, False), - ] - ) - def test_call_flags(self, _, do_resize, do_normalize): - # Initialize feature_extractor - feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) - feature_extractor.do_resize = do_resize - feature_extractor.do_normalize = do_normalize - # create random PIL images - image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) - - if do_resize: - expected_shapes = [ - ( - self.feature_extract_tester.num_channels, - self.feature_extract_tester.size, - self.feature_extract_tester.size, - ) - for _ in range(self.feature_extract_tester.batch_size) - ] - else: - expected_shapes = [(3, *x.size[::-1]) for x in image_inputs] - - pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"] - self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size) - for idx, image in enumerate(pixel_values): - self.assertEqual(image.shape, expected_shapes[idx]) - self.assertIsInstance(image, np.ndarray)