Skip to content

Commit 15c166a

Browse files
pmeier and NicolasHug authored
refactor to_pil_image and align array with tensor inputs (#8097)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent a0fcd08 commit 15c166a

File tree

2 files changed

+20
-33
lines changed

2 files changed

+20
-33
lines changed

test/test_transforms.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
661661
@pytest.mark.parametrize(
662662
"img_data, expected_mode",
663663
[
664-
(torch.Tensor(4, 4, 1).uniform_().numpy(), "F"),
664+
(torch.Tensor(4, 4, 1).uniform_().numpy(), "L"),
665665
(torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"),
666666
(torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"),
667667
(torch.IntTensor(4, 4, 1).random_().numpy(), "I"),
@@ -671,6 +671,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode
671671
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
672672
img = transform(img_data)
673673
assert img.mode == expected_mode
674+
if np.issubdtype(img_data.dtype, np.floating):
675+
img_data = (img_data * 255).astype(np.uint8)
674676
# note: we explicitly convert img's dtype because pytorch doesn't support uint16
675677
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
676678
torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype))
@@ -741,7 +743,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe
741743
@pytest.mark.parametrize(
742744
"img_data, expected_mode",
743745
[
744-
(torch.Tensor(4, 4).uniform_().numpy(), "F"),
746+
(torch.Tensor(4, 4).uniform_().numpy(), "L"),
745747
(torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"),
746748
(torch.ShortTensor(4, 4).random_().numpy(), "I;16"),
747749
(torch.IntTensor(4, 4).random_().numpy(), "I"),
@@ -751,6 +753,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
751753
transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage()
752754
img = transform(img_data)
753755
assert img.mode == expected_mode
756+
if np.issubdtype(img_data.dtype, np.floating):
757+
img_data = (img_data * 255).astype(np.uint8)
754758
np.testing.assert_allclose(img_data, img)
755759

756760
@pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"])
@@ -874,8 +878,6 @@ def test_ndarray_bad_types_to_pil_image(self):
874878
trans(np.ones([4, 4, 1], np.uint16))
875879
with pytest.raises(TypeError, match=reg_msg):
876880
trans(np.ones([4, 4, 1], np.uint32))
877-
with pytest.raises(TypeError, match=reg_msg):
878-
trans(np.ones([4, 4, 1], np.float64))
879881

880882
with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."):
881883
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))

torchvision/transforms/functional.py

+14-29
Original file line numberDiff line numberDiff line change
@@ -258,41 +258,26 @@ def to_pil_image(pic, mode=None):
258258
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
259259
_log_api_usage_once(to_pil_image)
260260

261-
if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
261+
if isinstance(pic, torch.Tensor):
262+
if pic.ndim == 3:
263+
pic = pic.permute((1, 2, 0))
264+
pic = pic.numpy(force=True)
265+
elif not isinstance(pic, np.ndarray):
262266
raise TypeError(f"pic should be Tensor or ndarray. Got {type(pic)}.")
263267

264-
elif isinstance(pic, torch.Tensor):
265-
if pic.ndimension() not in {2, 3}:
266-
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndimension()} dimensions.")
267-
268-
elif pic.ndimension() == 2:
269-
# if 2D image, add channel dimension (CHW)
270-
pic = pic.unsqueeze(0)
271-
272-
# check number of channels
273-
if pic.shape[-3] > 4:
274-
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-3]} channels.")
275-
276-
elif isinstance(pic, np.ndarray):
277-
if pic.ndim not in {2, 3}:
278-
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
279-
280-
elif pic.ndim == 2:
281-
# if 2D image, add channel dimension (HWC)
282-
pic = np.expand_dims(pic, 2)
268+
if pic.ndim == 2:
269+
# if 2D image, add channel dimension (HWC)
270+
pic = np.expand_dims(pic, 2)
271+
if pic.ndim != 3:
272+
raise ValueError(f"pic should be 2/3 dimensional. Got {pic.ndim} dimensions.")
283273

284-
# check number of channels
285-
if pic.shape[-1] > 4:
286-
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
274+
if pic.shape[-1] > 4:
275+
raise ValueError(f"pic should not have > 4 channels. Got {pic.shape[-1]} channels.")
287276

288277
npimg = pic
289-
if isinstance(pic, torch.Tensor):
290-
if pic.is_floating_point() and mode != "F":
291-
pic = pic.mul(255).byte()
292-
npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0))
293278

294-
if not isinstance(npimg, np.ndarray):
295-
raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, not {type(npimg)}")
279+
if np.issubdtype(npimg.dtype, np.floating) and mode != "F":
280+
npimg = (npimg * 255).astype(np.uint8)
296281

297282
if npimg.shape[2] == 1:
298283
expected_mode = None

0 commit comments

Comments (0)