@@ -661,7 +661,7 @@ def test_1_channel_float_tensor_to_pil_image(self):
661
661
@pytest .mark .parametrize (
662
662
"img_data, expected_mode" ,
663
663
[
664
- (torch .Tensor (4 , 4 , 1 ).uniform_ ().numpy (), "F " ),
664
+ (torch .Tensor (4 , 4 , 1 ).uniform_ ().numpy (), "L " ),
665
665
(torch .ByteTensor (4 , 4 , 1 ).random_ (0 , 255 ).numpy (), "L" ),
666
666
(torch .ShortTensor (4 , 4 , 1 ).random_ ().numpy (), "I;16" ),
667
667
(torch .IntTensor (4 , 4 , 1 ).random_ ().numpy (), "I" ),
@@ -671,6 +671,8 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode
671
671
transform = transforms .ToPILImage (mode = expected_mode ) if with_mode else transforms .ToPILImage ()
672
672
img = transform (img_data )
673
673
assert img .mode == expected_mode
674
+ if np .issubdtype (img_data .dtype , np .floating ):
675
+ img_data = (img_data * 255 ).astype (np .uint8 )
674
676
# note: we explicitly convert img's dtype because pytorch doesn't support uint16
675
677
# and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array
676
678
torch .testing .assert_close (img_data [:, :, 0 ], np .asarray (img ).astype (img_data .dtype ))
@@ -741,7 +743,7 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe
741
743
@pytest .mark .parametrize (
742
744
"img_data, expected_mode" ,
743
745
[
744
- (torch .Tensor (4 , 4 ).uniform_ ().numpy (), "F " ),
746
+ (torch .Tensor (4 , 4 ).uniform_ ().numpy (), "L " ),
745
747
(torch .ByteTensor (4 , 4 ).random_ (0 , 255 ).numpy (), "L" ),
746
748
(torch .ShortTensor (4 , 4 ).random_ ().numpy (), "I;16" ),
747
749
(torch .IntTensor (4 , 4 ).random_ ().numpy (), "I" ),
@@ -751,6 +753,8 @@ def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode):
751
753
transform = transforms .ToPILImage (mode = expected_mode ) if with_mode else transforms .ToPILImage ()
752
754
img = transform (img_data )
753
755
assert img .mode == expected_mode
756
+ if np .issubdtype (img_data .dtype , np .floating ):
757
+ img_data = (img_data * 255 ).astype (np .uint8 )
754
758
np .testing .assert_allclose (img_data , img )
755
759
756
760
@pytest .mark .parametrize ("expected_mode" , [None , "RGB" , "HSV" , "YCbCr" ])
@@ -874,8 +878,6 @@ def test_ndarray_bad_types_to_pil_image(self):
874
878
trans (np .ones ([4 , 4 , 1 ], np .uint16 ))
875
879
with pytest .raises (TypeError , match = reg_msg ):
876
880
trans (np .ones ([4 , 4 , 1 ], np .uint32 ))
877
- with pytest .raises (TypeError , match = reg_msg ):
878
- trans (np .ones ([4 , 4 , 1 ], np .float64 ))
879
881
880
882
with pytest .raises (ValueError , match = r"pic should be 2/3 dimensional. Got \d+ dimensions." ):
881
883
transforms .ToPILImage ()(np .ones ([1 , 4 , 4 , 3 ]))
0 commit comments