From 4241539d6c39f8ccf75869a81f046f6f98e527c7 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:15:30 +0000 Subject: [PATCH 1/3] Add check before int casting for PIL conversion --- src/transformers/image_transforms.py | 13 +++++++++++-- tests/test_image_transforms.py | 5 +++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 0ae19c43c74a..d9fd6bf4cac2 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -131,7 +131,8 @@ def to_pil_image( The image to convert to the `PIL.Image` format. do_rescale (`bool`, *optional*): Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default - to `True` if the image type is a floating type, `False` otherwise. + to `True` if the image type is a floating type and casting to `int` would result in a loss of precision, and + `False` otherwise. Returns: `PIL.Image.Image`: The converted image. @@ -156,9 +157,17 @@ def to_pil_image( image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image # PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed. - do_rescale = isinstance(image.flat[0], (float, np.float32, np.float64)) if do_rescale is None else do_rescale + if do_rescale is None: + do_rescale = isinstance(image.flat[0], (float, np.floating)) and not np.allclose(image, image.astype(int)) + if do_rescale: image = rescale(image, 255) + + if np.any(image < 0) or np.any(image > 255): + raise ValueError( + "The image to be converted to a PIL image contains values outside the range [0, 255], " + f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." + ) image = image.astype(np.uint8) return PIL.Image.fromarray(image) diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 2287fdbf2ce9..0efefc7c8fbb 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -96,6 +96,11 @@ def test_to_pil_image_from_float(self, name, image_shape, dtype): # make sure image is correctly rescaled self.assertTrue(np.abs(np.asarray(pil_image)).sum() > 0) + # Make sure that an exception is raised if image is not in [0, 1] + image = np.random.randn(*image_shape).astype(dtype) + with self.assertRaises(ValueError): + to_pil_image(image) + @require_tf def test_to_pil_image_from_tensorflow(self): # channels_first From c33f9fde9e6a9f9116e5d0f6ccb59aa89039f579 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Mon, 6 Mar 2023 15:10:56 +0000 Subject: [PATCH 2/3] Line length --- src/transformers/image_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index d9fd6bf4cac2..b679e0bc3aad 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -131,8 +131,8 @@ def to_pil_image( The image to convert to the `PIL.Image` format. do_rescale (`bool`, *optional*): Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default - to `True` if the image type is a floating type and casting to `int` would result in a loss of precision, and - `False` otherwise. + to `True` if the image type is a floating type and casting to `int` would result in a loss of precision, + and `False` otherwise. Returns: `PIL.Image.Image`: The converted image. From 2304eb6e0fd6cd0f19fbbe24cfd7e63fd6913ec1 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 7 Mar 2023 10:55:24 +0000 Subject: [PATCH 3/3] Tidier logic --- src/transformers/image_transforms.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index b679e0bc3aad..16016d97042c 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -158,16 +158,19 @@ def to_pil_image( # PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed. if do_rescale is None: - do_rescale = isinstance(image.flat[0], (float, np.floating)) and not np.allclose(image, image.astype(int)) + if np.all(0 <= image) and np.all(image <= 1): + do_rescale = True + elif np.allclose(image, image.astype(int)): + do_rescale = False + else: + raise ValueError( + "The image to be converted to a PIL image contains values outside the range [0, 1], " + f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." + ) if do_rescale: image = rescale(image, 255) - if np.any(image < 0) or np.any(image > 255): - raise ValueError( - "The image to be converted to a PIL image contains values outside the range [0, 255], " - f"got [{image.min()}, {image.max()}] which cannot be converted to uint8." - ) image = image.astype(np.uint8) return PIL.Image.fromarray(image)