From 904a21bdeff96a26a6ce76112192b1b6277c04a7 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 14 Mar 2023 13:56:12 +0000 Subject: [PATCH 1/4] Don't rescale if in and in range 0-255 --- src/transformers/image_transforms.py | 10 ++++++---- tests/test_image_transforms.py | 29 +++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 16016d97042c..873a7dedd9b4 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -156,12 +156,14 @@ def to_pil_image( # If there is a single channel, we squeeze it, as otherwise PIL can't handle it. image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image - # PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed. + # PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed. if do_rescale is None: - if np.all(0 <= image) and np.all(image <= 1): - do_rescale = True - elif np.allclose(image, image.astype(int)): + if image.dtype == np.uint8: + do_rescale = False + elif np.allclose(image, image.astype(int)) and np.all(0 <= image) and np.all(image <= 255): do_rescale = False + elif np.all(0 <= image) and np.all(image <= 1): + do_rescale = True else: raise ValueError( "The image to be converted to a PIL image contains values outside the range [0, 1], " diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 0efefc7c8fbb..aa92823fa26d 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -101,6 +101,33 @@ def test_to_pil_image_from_float(self, name, image_shape, dtype): with self.assertRaises(ValueError): to_pil_image(image) + # Make sure binary mask remains a binary mask + image = np.random.randint(0, 2, image_shape).astype(np.uint8) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + @require_vision + def test_to_pil_image_from_mask(self): + # Make sure binary mask remains a binary mask + image = np.random.randint(0, 2, (3, 4, 5)).astype(np.uint8) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + np_img = np.asarray(pil_image) + self.assertTrue(np_img.min() == 0) + self.assertTrue(np_img.max() == 1) + + image = np.random.randint(0, 2, (3, 4, 5)).astype(np.float32) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + np_img = np.asarray(pil_image) + self.assertTrue(np_img.min() == 0) + self.assertTrue(np_img.max() == 1) + @require_tf def test_to_pil_image_from_tensorflow(self): # channels_first @@ -222,7 +249,7 @@ def test_resize(self): self.assertIsInstance(resized_image, np.ndarray) self.assertEqual(resized_image.shape, (30, 40, 3)) - # Check PIL.Image.Image is return if return_numpy=False + # Check PIL.Image.Image is returned if return_numpy=False resized_image = resize(image, (30, 40), return_numpy=False) self.assertIsInstance(resized_image, PIL.Image.Image) # PIL size is in (width, height) order From 91871bc598cba7a338bd70e3755209962e036378 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:03:52 +0000 Subject: [PATCH 2/4] Raise value error if int values too large --- src/transformers/image_transforms.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 873a7dedd9b4..8f3fac73dd5f 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -160,8 +160,14 @@ def to_pil_image( if do_rescale is None: if image.dtype == np.uint8: do_rescale = False - elif np.allclose(image, image.astype(int)) and np.all(0 <= image) and np.all(image <= 255): - do_rescale = False + elif np.allclose(image, image.astype(int)): + if np.all(0 <= image) and np.all(image <= 255): + do_rescale = False + else: + raise ValueError( + "The image to be converted to a PIL image contains values outside the range [0, 255], " + f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." + ) elif np.all(0 <= image) and np.all(image <= 1): do_rescale = True else: From 06326eb0f3cb61a6df670e0808767546c1c531bc Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:44:07 +0000 Subject: [PATCH 3/4] Update tests/test_image_transforms.py --- tests/test_image_transforms.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index aa92823fa26d..ff1383be41d6 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -101,11 +101,6 @@ def test_to_pil_image_from_float(self, name, image_shape, dtype): with self.assertRaises(ValueError): to_pil_image(image) - # Make sure binary mask remains a binary mask - image = np.random.randint(0, 2, image_shape).astype(np.uint8) - pil_image = to_pil_image(image) - self.assertIsInstance(pil_image, PIL.Image.Image) - self.assertEqual(pil_image.size, (5, 4)) @require_vision def test_to_pil_image_from_mask(self): From 5e9befdad4760d764a35123daacd1cb1d1afd751 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:44:34 +0000 Subject: [PATCH 4/4] Update tests/test_image_transforms.py --- tests/test_image_transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index ff1383be41d6..79580e0876e4 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -101,7 +101,6 @@ def test_to_pil_image_from_float(self, name, image_shape, dtype): with self.assertRaises(ValueError): to_pil_image(image) - @require_vision def test_to_pil_image_from_mask(self): # Make sure binary mask remains a binary mask