diff --git a/src/transformers/models/mobilevit/feature_extraction_mobilevit.py b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py index 51e022b809c9..274c83fee4f3 100644 --- a/src/transformers/models/mobilevit/feature_extraction_mobilevit.py +++ b/src/transformers/models/mobilevit/feature_extraction_mobilevit.py @@ -93,13 +93,15 @@ def __call__( tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`): + If set, will return a tensor of a particular framework. + + Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` object. + - `'pt'`: Return PyTorch `torch.Tensor` object. + - `'np'`: Return NumPy `np.ndarray` object. + - `'jax'`: Return JAX `jnp.ndarray` object. + - None: Return list of `np.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/tests/models/mobilevit/test_feature_extraction_mobilevit.py b/tests/models/mobilevit/test_feature_extraction_mobilevit.py index f13267c541c9..54b13947b824 100644 --- a/tests/models/mobilevit/test_feature_extraction_mobilevit.py +++ b/tests/models/mobilevit/test_feature_extraction_mobilevit.py @@ -18,6 +18,7 @@ import numpy as np +from parameterized import parameterized from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_torch_available, is_vision_available @@ -189,3 +190,55 @@ def test_call_pytorch(self): self.feature_extract_tester.crop_size, ), ) + + @parameterized.expand( + [ + ("do_resize_True_do_center_crop_True_do_flip_channel_order_True", True, True, True), + ("do_resize_True_do_center_crop_True_do_flip_channel_order_False", True, True, False), + ("do_resize_True_do_center_crop_False_do_flip_channel_order_True", True, False, True), + ("do_resize_True_do_center_crop_False_do_flip_channel_order_False", True, False, False), + ("do_resize_False_do_center_crop_True_do_flip_channel_order_True", False, True, True), + ("do_resize_False_do_center_crop_True_do_flip_channel_order_False", False, True, False), + ("do_resize_False_do_center_crop_False_do_flip_channel_order_True", False, False, True), + ("do_resize_False_do_center_crop_False_do_flip_channel_order_False", False, False, False), + ] + ) + def test_call_flags(self, _, do_resize, do_center_crop, do_flip_channel_order): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + feature_extractor.do_center_crop = do_center_crop + feature_extractor.do_resize = do_resize + feature_extractor.do_flip_channel_order = do_flip_channel_order + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + + expected_shapes = [(3, *x.size[::-1]) for x in image_inputs] + if do_resize: + # Same size logic inside resized + resized_shapes = [] + for shape in expected_shapes: + c, h, w = shape + short, long = (w, h) if w <= h else (h, w) + min_size = self.feature_extract_tester.size + if short == min_size: + resized_shapes.append((c, h, w)) + else: + short, long = min_size, int(long * min_size / short) + resized_shape = (c, long, short) if w <= h else (c, short, long) + resized_shapes.append(resized_shape) + expected_shapes = resized_shapes + if do_center_crop: + expected_shapes = [ + ( + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ) + for _ in range(self.feature_extract_tester.batch_size) + ] + + pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"] + self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size) + for idx, image in enumerate(pixel_values): + self.assertEqual(image.shape, expected_shapes[idx]) + self.assertIsInstance(image, np.ndarray)