diff --git a/src/transformers/models/segformer/feature_extraction_segformer.py b/src/transformers/models/segformer/feature_extraction_segformer.py index 60698030e99d..0a9ae01ef121 100644 --- a/src/transformers/models/segformer/feature_extraction_segformer.py +++ b/src/transformers/models/segformer/feature_extraction_segformer.py @@ -112,15 +112,13 @@ def __call__( segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*): Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations. - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`): - If set, will return a tensor of a particular framework. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: - Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` object. - - `'pt'`: Return PyTorch `torch.Tensor` object. - - `'np'`: Return NumPy `np.ndarray` object. - - `'jax'`: Return JAX `jnp.ndarray` object. - - None: Return list of `np.ndarray` objects. + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -195,14 +193,8 @@ def __call__( self.resize(map, size=self.size, resample=Image.NEAREST) for map in segmentation_maps ] - # if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here - make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3) - images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images] - if self.do_normalize: - images = [ - self.normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True) for image in images - ] + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] # return as BatchFeature data = {"pixel_values": images} diff --git a/tests/models/segformer/test_feature_extraction_segformer.py b/tests/models/segformer/test_feature_extraction_segformer.py index 2efa592d9815..75083012d875 100644 --- a/tests/models/segformer/test_feature_extraction_segformer.py +++ b/tests/models/segformer/test_feature_extraction_segformer.py @@ -19,7 +19,6 @@ import numpy as np from datasets import load_dataset -from parameterized import parameterized from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_torch_available, is_vision_available @@ -334,40 +333,3 @@ def test_reduce_labels(self): encoding = feature_extractor(image, map, return_tensors="pt") self.assertTrue(encoding["labels"].min().item() >= 0) self.assertTrue(encoding["labels"].max().item() <= 255) - - @parameterized.expand( - [ - ("do_resize_True_do_normalize_True", True, True), - ("do_resize_True_do_normalize_False", True, False), - ("do_resize_True_do_normalize_True", True, True), - ("do_resize_True_do_normalize_False", True, False), - ("do_resize_False_do_normalize_True", False, True), - ("do_resize_False_do_normalize_False", False, False), - ("do_resize_False_do_normalize_True", False, True), - ("do_resize_False_do_normalize_False", False, False), - ] - ) - def test_call_flags(self, _, do_resize, do_normalize): - # Initialize feature_extractor - feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) - feature_extractor.do_resize = do_resize - feature_extractor.do_normalize = do_normalize - # create random PIL images - image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) - - expected_shapes = [(3, *x.size[::-1]) for x in image_inputs] - if do_resize: - expected_shapes = [ - ( - self.feature_extract_tester.num_channels, - self.feature_extract_tester.size, - self.feature_extract_tester.size, - ) - for _ in range(self.feature_extract_tester.batch_size) - ] - - pixel_values = feature_extractor(image_inputs, return_tensors=None)["pixel_values"] - self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size) - for idx, image in enumerate(pixel_values): - self.assertEqual(image.shape, expected_shapes[idx]) - self.assertIsInstance(image, np.ndarray)