Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,20 @@ def to_pil_image(
# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image

# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
# PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
if do_rescale is None:
if np.all(0 <= image) and np.all(image <= 1):
do_rescale = True
elif np.allclose(image, image.astype(int)):
if image.dtype == np.uint8:
do_rescale = False
elif np.allclose(image, image.astype(int)):
if np.all(0 <= image) and np.all(image <= 255):
do_rescale = False
else:
raise ValueError(
"The image to be converted to a PIL image contains values outside the range [0, 255], "
f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
)
elif np.all(0 <= image) and np.all(image <= 1):
do_rescale = True
else:
raise ValueError(
"The image to be converted to a PIL image contains values outside the range [0, 1], "
Expand Down
23 changes: 22 additions & 1 deletion tests/test_image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,27 @@ def test_to_pil_image_from_float(self, name, image_shape, dtype):
with self.assertRaises(ValueError):
to_pil_image(image)

@require_vision
def test_to_pil_image_from_mask(self):
# Make sure binary mask remains a binary mask
image = np.random.randint(0, 2, (3, 4, 5)).astype(np.uint8)
pil_image = to_pil_image(image)
self.assertIsInstance(pil_image, PIL.Image.Image)
self.assertEqual(pil_image.size, (5, 4))

np_img = np.asarray(pil_image)
self.assertTrue(np_img.min() == 0)
self.assertTrue(np_img.max() == 1)

image = np.random.randint(0, 2, (3, 4, 5)).astype(np.float32)
pil_image = to_pil_image(image)
self.assertIsInstance(pil_image, PIL.Image.Image)
self.assertEqual(pil_image.size, (5, 4))

np_img = np.asarray(pil_image)
self.assertTrue(np_img.min() == 0)
self.assertTrue(np_img.max() == 1)

@require_tf
def test_to_pil_image_from_tensorflow(self):
# channels_first
Expand Down Expand Up @@ -222,7 +243,7 @@ def test_resize(self):
self.assertIsInstance(resized_image, np.ndarray)
self.assertEqual(resized_image.shape, (30, 40, 3))

# Check PIL.Image.Image is return if return_numpy=False
# Check PIL.Image.Image is returned if return_numpy=False
resized_image = resize(image, (30, 40), return_numpy=False)
self.assertIsInstance(resized_image, PIL.Image.Image)
# PIL size is in (width, height) order
Expand Down