diff --git a/src/transformers/models/detr/feature_extraction_detr.py b/src/transformers/models/detr/feature_extraction_detr.py index 91e406c71fc9..b38dbace40de 100644 --- a/src/transformers/models/detr/feature_extraction_detr.py +++ b/src/transformers/models/detr/feature_extraction_detr.py @@ -378,14 +378,14 @@ def get_size(image_size, size, max_size=None): return rescaled_image, target - def _normalize(self, image, mean, std, target=None): + def _normalize(self, image, mean, std, target=None, rescale=False): """ Normalize the image with a certain mean and std. If given, also normalize the target bounding boxes based on the size of the image. """ - image = self.normalize(image, mean=mean, std=std) + image = self.normalize(image, mean=mean, std=std, rescale=rescale) if target is None: return image, None @@ -455,9 +455,15 @@ def __call__( - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor` - objects. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `None`): + If set, will return a tensor of a particular framework. + + Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + - None: Return list of `np.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -567,17 +573,22 @@ def __call__( for idx, image in enumerate(images): images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0] + # if do_normalize=False, the casting to a numpy array won't happen, so we need to do it here + make_channel_first = True if isinstance(images[0], Image.Image) else images[0].shape[-1] in (1, 3) + images = [self.to_numpy_array(image, rescale=False, channel_first=make_channel_first) for image in images] + if self.do_normalize: if annotations is not None: for idx, (image, target) in enumerate(zip(images, annotations)): image, target = self._normalize( - image=image, mean=self.image_mean, std=self.image_std, target=target + image=image, mean=self.image_mean, std=self.image_std, target=target, rescale=True ) images[idx] = image annotations[idx] = target else: images = [ - self._normalize(image=image, mean=self.image_mean, std=self.image_std)[0] for image in images + self._normalize(image=image, mean=self.image_mean, std=self.image_std, rescale=True)[0] + for image in images ] if pad_and_return_pixel_mask: diff --git a/tests/models/detr/test_feature_extraction_detr.py b/tests/models/detr/test_feature_extraction_detr.py index 58bde80fbbb1..5f3a93c8c14b 100644 --- a/tests/models/detr/test_feature_extraction_detr.py +++ b/tests/models/detr/test_feature_extraction_detr.py @@ -20,6 +20,7 @@ import numpy as np +from parameterized import parameterized from transformers.testing_utils import require_torch, require_vision, slow from transformers.utils import is_torch_available, is_vision_available @@ -336,3 +337,43 @@ def test_call_pytorch_with_coco_panoptic_annotations(self): # verify size expected_size = torch.tensor([800, 1066]) assert torch.allclose(encoding["labels"][0]["size"], expected_size) + + @parameterized.expand( + [ + ("do_resize_True_do_normalize_True_pad_True", True, True, True), + ("do_resize_True_do_normalize_False_pad_True", True, False, True), + ("do_resize_False_do_normalize_False_pad_True", False, False, True), + ("do_resize_False_do_normalize_True_pad_True", False, True, True), + ("do_resize_True_do_normalize_True_pad_False", True, True, False), + ("do_resize_True_do_normalize_False_pad_False", True, False, False), + ("do_resize_False_do_normalize_False_pad_False", False, False, False), + ("do_resize_False_do_normalize_True_pad_False", False, True, False), + ] + ) + def test_call_flags(self, _, do_resize, do_normalize, pad): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + feature_extractor.do_resize = do_resize + feature_extractor.do_normalize = do_normalize + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + + expected_shapes = [(3, *x.size[::-1]) for x in image_inputs] + if do_resize: + expected_shapes = [ + ( + self.feature_extract_tester.num_channels, + *self.feature_extract_tester.get_expected_values([image], batched=False), + ) + for image in image_inputs + ] + if pad: + expected_shapes = [tuple(max(x) for x in zip(*expected_shapes))] * len(image_inputs) + + pixel_values = feature_extractor(image_inputs, pad_and_return_pixel_mask=pad, return_tensors=None)[ + "pixel_values" + ] + self.assertEqual(len(pixel_values), self.feature_extract_tester.batch_size) + for idx, image in enumerate(pixel_values): + self.assertEqual(image.shape, expected_shapes[idx]) + self.assertIsInstance(image, np.ndarray)