diff --git a/src/transformers/models/glpn/feature_extraction_glpn.py b/src/transformers/models/glpn/feature_extraction_glpn.py
index 2694d56b898b..e1defdbea34e 100644
--- a/src/transformers/models/glpn/feature_extraction_glpn.py
+++ b/src/transformers/models/glpn/feature_extraction_glpn.py
@@ -86,13 +86,15 @@ def __call__(
                 tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C
                 is a number of channels, H and W are image height and width.
 
-            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
-                If set, will return tensors of a particular framework. Acceptable values are:
+            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`):
+                If set, will return a tensor of a particular framework.
 
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
+                Acceptable values are:
+                - `'tf'`: Return a TensorFlow `tf.constant` object.
+                - `'pt'`: Return a PyTorch `torch.Tensor` object.
+                - `'np'`: Return a NumPy `np.ndarray` object.
+                - `'jax'`: Return a JAX `jnp.ndarray` object.
+                - `None`: Return a list of `np.ndarray` objects.
 
         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -129,8 +131,13 @@ def __call__(
             images = [
                 self._resize(image=image, size_divisor=self.size_divisor, resample=self.resample) for image in images
             ]
+
+        # if do_rescale=False, the cast to a numpy array won't happen later, so do it here
+        make_channel_first = isinstance(images[0], Image.Image) or images[0].shape[-1] in (1, 3)
+        images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images]
+
         if self.do_rescale:
-            images = [self.to_numpy_array(image=image) for image in images]
+            images = [self.rescale(image=image.astype(np.float32), scale=1 / 255.0) for image in images]
 
         # return as BatchFeature
         data = {"pixel_values": images}
diff --git a/tests/models/glpn/test_feature_extraction_glpn.py b/tests/models/glpn/test_feature_extraction_glpn.py
index 4e7f2bdf5c78..8fcae72bd513 100644
--- a/tests/models/glpn/test_feature_extraction_glpn.py
+++ b/tests/models/glpn/test_feature_extraction_glpn.py
@@ -18,6 +18,7 @@
 
 import numpy as np
 
+from parameterized import parameterized
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 
@@ -125,3 +126,37 @@ def test_call_pytorch(self):
         encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
         self.assertTrue(encoded_images.shape[-1] % self.feature_extract_tester.size_divisor == 0)
         self.assertTrue(encoded_images.shape[-2] % self.feature_extract_tester.size_divisor == 0)
+
+    @parameterized.expand(
+        [
+            ("do_resize_True_do_rescale_True", True, True),
+            ("do_resize_True_do_rescale_False", True, False),
+            ("do_resize_False_do_rescale_True", False, True),
+            ("do_resize_False_do_rescale_False", False, False),
+        ]
+    )
+    def test_call_flags(self, _, do_resize, do_rescale):
+        # Initialize feature_extractor with the flags under test
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+        feature_extractor.do_resize = do_resize
+        feature_extractor.do_rescale = do_rescale
+        # create random PyTorch tensors
+        image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+
+        expected_shapes = [x.shape for x in image_inputs]
+        if do_resize:
+            size_divisor = self.feature_extract_tester.size_divisor
+            expected_shapes = [
+                (
+                    self.feature_extract_tester.num_channels,
+                    (shape[1] // size_divisor) * size_divisor,
+                    (shape[2] // size_divisor) * size_divisor,
+                )
+                for shape in expected_shapes
+            ]
+
+        pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"]
+        self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size)
+        for idx, image in enumerate(pixel_values):
+            self.assertEqual(image.shape, expected_shapes[idx])
+            self.assertIsInstance(image, np.ndarray)
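
Minimal usage sketch of the new return_tensors=None default (an illustration, not part of the diff above; it assumes the public GLPNFeatureExtractor export and its default size_divisor=32):

    import numpy as np
    from PIL import Image
    from transformers import GLPNFeatureExtractor

    feature_extractor = GLPNFeatureExtractor()  # do_resize=True, do_rescale=True by default

    # dummy 640x480 RGB image; 480 and 640 are already multiples of size_divisor=32,
    # so do_resize leaves the spatial size unchanged
    image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

    # with return_tensors=None (the new default), a plain list of channels-first
    # np.ndarray is returned, rescaled to float32 in [0, 1] since do_rescale=True
    pixel_values = feature_extractor(image, return_tensors=None)["pixel_values"]
    print(type(pixel_values[0]), pixel_values[0].shape, pixel_values[0].dtype)
    # <class 'numpy.ndarray'> (3, 480, 640) float32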