@@ -11,11 +11,15 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
1111 return x .ndim >= 2
1212
1313
14+ def _assert_image_tensor (img ):
15+ if not _is_tensor_a_torch_image (img ):
16+ raise TypeError ("Tensor is not a torch image." )
17+
18+
1419def _get_image_size (img : Tensor ) -> List [int ]:
1520 """Returns (w, h) of tensor image"""
16- if _is_tensor_a_torch_image (img ):
17- return [img .shape [- 1 ], img .shape [- 2 ]]
18- raise TypeError ("Unexpected input type" )
21+ _assert_image_tensor (img )
22+ return [img .shape [- 1 ], img .shape [- 2 ]]
1923
2024
2125def _get_image_num_channels (img : Tensor ) -> int :
@@ -143,8 +147,7 @@ def vflip(img: Tensor) -> Tensor:
143147 Returns:
144148 Tensor: Vertically flipped image Tensor.
145149 """
146- if not _is_tensor_a_torch_image (img ):
147- raise TypeError ('tensor is not a torch image.' )
150+ _assert_image_tensor (img )
148151
149152 return img .flip (- 2 )
150153
@@ -163,8 +166,7 @@ def hflip(img: Tensor) -> Tensor:
163166 Returns:
164167 Tensor: Horizontally flipped image Tensor.
165168 """
166- if not _is_tensor_a_torch_image (img ):
167- raise TypeError ('tensor is not a torch image.' )
169+ _assert_image_tensor (img )
168170
169171 return img .flip (- 1 )
170172
@@ -187,8 +189,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
187189 Returns:
188190 Tensor: Cropped image.
189191 """
190- if not _is_tensor_a_torch_image (img ):
191- raise TypeError ("tensor is not a torch image." )
192+ _assert_image_tensor (img )
192193
193194 return img [..., top :top + height , left :left + width ]
194195
@@ -254,8 +255,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
254255 if brightness_factor < 0 :
255256 raise ValueError ('brightness_factor ({}) is not non-negative.' .format (brightness_factor ))
256257
257- if not _is_tensor_a_torch_image (img ):
258- raise TypeError ('tensor is not a torch image.' )
258+ _assert_image_tensor (img )
259259
260260 _assert_channels (img , [1 , 3 ])
261261
@@ -282,8 +282,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
282282 if contrast_factor < 0 :
283283 raise ValueError ('contrast_factor ({}) is not non-negative.' .format (contrast_factor ))
284284
285- if not _is_tensor_a_torch_image (img ):
286- raise TypeError ('tensor is not a torch image.' )
285+ _assert_image_tensor (img )
287286
288287 _assert_channels (img , [3 ])
289288
@@ -326,9 +325,11 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
326325 if not (- 0.5 <= hue_factor <= 0.5 ):
327326 raise ValueError ('hue_factor ({}) is not in [-0.5, 0.5].' .format (hue_factor ))
328327
329- if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image ( img ) ):
328+ if not (isinstance (img , torch .Tensor )):
330329 raise TypeError ('Input img should be Tensor image' )
331330
331+ _assert_image_tensor (img )
332+
332333 _assert_channels (img , [3 ])
333334
334335 orig_dtype = img .dtype
@@ -367,8 +368,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
367368 if saturation_factor < 0 :
368369 raise ValueError ('saturation_factor ({}) is not non-negative.' .format (saturation_factor ))
369370
370- if not _is_tensor_a_torch_image (img ):
371- raise TypeError ('tensor is not a torch image.' )
371+ _assert_image_tensor (img )
372372
373373 _assert_channels (img , [3 ])
374374
@@ -447,8 +447,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
447447 "Please, use ``F.center_crop`` instead."
448448 )
449449
450- if not _is_tensor_a_torch_image (img ):
451- raise TypeError ('tensor is not a torch image.' )
450+ _assert_image_tensor (img )
452451
453452 _ , image_width , image_height = img .size ()
454453 crop_height , crop_width = output_size
@@ -497,8 +496,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
497496 "Please, use ``F.five_crop`` instead."
498497 )
499498
500- if not _is_tensor_a_torch_image (img ):
501- raise TypeError ('tensor is not a torch image.' )
499+ _assert_image_tensor (img )
502500
503501 assert len (size ) == 2 , "Please provide only two dimensions (h, w) for size."
504502
@@ -553,8 +551,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
553551 "Please, use ``F.ten_crop`` instead."
554552 )
555553
556- if not _is_tensor_a_torch_image (img ):
557- raise TypeError ('tensor is not a torch image.' )
554+ _assert_image_tensor (img )
558555
559556 assert len (size ) == 2 , "Please provide only two dimensions (h, w) for size."
560557 first_five = five_crop (img , size )
@@ -703,8 +700,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
703700 Returns:
704701 Tensor: Padded image.
705702 """
706- if not _is_tensor_a_torch_image (img ):
707- raise TypeError ("tensor is not a torch image." )
703+ _assert_image_tensor (img )
708704
709705 if not isinstance (padding , (int , tuple , list )):
710706 raise TypeError ("Got inappropriate padding arg" )
@@ -796,8 +792,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
796792 Returns:
797793 Tensor: Resized image.
798794 """
799- if not _is_tensor_a_torch_image (img ):
800- raise TypeError ("tensor is not a torch image." )
795+ _assert_image_tensor (img )
801796
802797 if not isinstance (size , (int , tuple , list )):
803798 raise TypeError ("Got inappropriate size arg" )
@@ -855,8 +850,11 @@ def _assert_grid_transform_inputs(
855850 supported_interpolation_modes : List [str ],
856851 coeffs : Optional [List [float ]] = None ,
857852):
858- if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image (img )):
859- raise TypeError ("Input img should be Tensor Image" )
853+
854+ if not (isinstance (img , torch .Tensor )):
855+ raise TypeError ("Input img should be Tensor" )
856+
857+ _assert_image_tensor (img )
860858
861859 if matrix is not None and not isinstance (matrix , list ):
862860 raise TypeError ("Argument matrix should be a list" )
@@ -1112,8 +1110,11 @@ def perspective(
11121110 Returns:
11131111 Tensor: transformed image.
11141112 """
1115- if not (isinstance (img , torch .Tensor ) and _is_tensor_a_torch_image (img )):
1116- raise TypeError ('Input img should be Tensor Image' )
1113+
1114+ if not (isinstance (img , torch .Tensor )):
1115+ raise TypeError ('Input img should be Tensor.' )
1116+
1117+ _assert_image_tensor (img )
11171118
11181119 _assert_grid_transform_inputs (
11191120 img ,
@@ -1165,8 +1166,11 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
11651166 Returns:
11661167 Tensor: An image that is blurred using gaussian kernel of given parameters
11671168 """
1168- if not (isinstance (img , torch .Tensor ) or _is_tensor_a_torch_image (img )):
1169- raise TypeError ('img should be Tensor Image. Got {}' .format (type (img )))
1169+
1170+ if not (isinstance (img , torch .Tensor )):
1171+ raise TypeError ('img should be Tensor. Got {}' .format (type (img )))
1172+
1173+ _assert_image_tensor (img )
11701174
11711175 dtype = img .dtype if torch .is_floating_point (img ) else torch .float32
11721176 kernel = _get_gaussian_kernel2d (kernel_size , sigma , dtype = dtype , device = img .device )
@@ -1184,8 +1188,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
11841188
11851189
11861190def invert (img : Tensor ) -> Tensor :
1187- if not _is_tensor_a_torch_image ( img ):
1188- raise TypeError ( 'tensor is not a torch image.' )
1191+
1192+ _assert_image_tensor ( img )
11891193
11901194 if img .ndim < 3 :
11911195 raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1197,8 +1201,8 @@ def invert(img: Tensor) -> Tensor:
11971201
11981202
11991203def posterize (img : Tensor , bits : int ) -> Tensor :
1200- if not _is_tensor_a_torch_image ( img ):
1201- raise TypeError ( 'tensor is not a torch image.' )
1204+
1205+ _assert_image_tensor ( img )
12021206
12031207 if img .ndim < 3 :
12041208 raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1211,8 +1215,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:
12111215
12121216
12131217def solarize (img : Tensor , threshold : float ) -> Tensor :
1214- if not _is_tensor_a_torch_image ( img ):
1215- raise TypeError ( 'tensor is not a torch image.' )
1218+
1219+ _assert_image_tensor ( img )
12161220
12171221 if img .ndim < 3 :
12181222 raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1245,8 +1249,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
12451249 if sharpness_factor < 0 :
12461250 raise ValueError ('sharpness_factor ({}) is not non-negative.' .format (sharpness_factor ))
12471251
1248- if not _is_tensor_a_torch_image (img ):
1249- raise TypeError ('tensor is not a torch image.' )
1252+ _assert_image_tensor (img )
12501253
12511254 _assert_channels (img , [1 , 3 ])
12521255
@@ -1257,8 +1260,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
12571260
12581261
12591262def autocontrast (img : Tensor ) -> Tensor :
1260- if not _is_tensor_a_torch_image ( img ):
1261- raise TypeError ( 'tensor is not a torch image.' )
1263+
1264+ _assert_image_tensor ( img )
12621265
12631266 if img .ndim < 3 :
12641267 raise TypeError ("Input image tensor should have at least 3 dimensions, but found {}" .format (img .ndim ))
@@ -1297,8 +1300,8 @@ def _equalize_single_image(img: Tensor) -> Tensor:
12971300
12981301
12991302def equalize (img : Tensor ) -> Tensor :
1300- if not _is_tensor_a_torch_image ( img ):
1301- raise TypeError ( 'tensor is not a torch image.' )
1303+
1304+ _assert_image_tensor ( img )
13021305
13031306 if not (3 <= img .ndim <= 4 ):
13041307 raise TypeError ("Input image tensor should have 3 or 4 dimensions, but found {}" .format (img .ndim ))
0 commit comments