@@ -40,6 +40,7 @@ class ImageReadMode(Enum):
40
40
GRAY_ALPHA = 2
41
41
RGB = 3
42
42
RGB_ALPHA = 4
43
+ RGBA = RGB_ALPHA # Alias for convenience
43
44
44
45
45
46
def read_file (path : str ) -> torch .Tensor :
@@ -92,7 +93,7 @@ def decode_png(
92
93
Args:
93
94
input (Tensor[1]): a one dimensional uint8 tensor containing
94
95
the raw bytes of the PNG image.
95
- mode (ImageReadMode): the read mode used for optionally
96
+ mode (str or ImageReadMode): the read mode used for optionally
96
97
converting the image. Default: ``ImageReadMode.UNCHANGED``.
97
98
See `ImageReadMode` class for more information on various
98
99
available modes.
@@ -104,6 +105,8 @@ def decode_png(
104
105
"""
105
106
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
106
107
_log_api_usage_once (decode_png )
108
+ if isinstance (mode , str ):
109
+ mode = ImageReadMode [mode .upper ()]
107
110
output = torch .ops .image .decode_png (input , mode .value , apply_exif_orientation )
108
111
return output
109
112
@@ -168,7 +171,7 @@ def decode_jpeg(
168
171
input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
169
172
the raw bytes of the JPEG image. The tensor(s) must be on CPU,
170
173
regardless of the ``device`` parameter.
171
- mode (ImageReadMode): the read mode used for optionally
174
+ mode (str or ImageReadMode): the read mode used for optionally
172
175
converting the image(s). The supported modes are: ``ImageReadMode.UNCHANGED``,
173
176
``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``
174
177
Default: ``ImageReadMode.UNCHANGED``.
@@ -198,6 +201,8 @@ def decode_jpeg(
198
201
_log_api_usage_once (decode_jpeg )
199
202
if isinstance (device , str ):
200
203
device = torch .device (device )
204
+ if isinstance (mode , str ):
205
+ mode = ImageReadMode [mode .upper ()]
201
206
202
207
if isinstance (input , list ):
203
208
if len (input ) == 0 :
@@ -298,7 +303,7 @@ def decode_image(
298
303
input (Tensor or str or ``pathlib.Path``): The image to decode. If a
299
304
tensor is passed, it must be one dimensional uint8 tensor containing
300
305
the raw bytes of the image. Otherwise, this must be a path to the image file.
301
- mode (ImageReadMode): the read mode used for optionally converting the image.
306
+ mode (str or ImageReadMode): the read mode used for optionally converting the image.
302
307
Default: ``ImageReadMode.UNCHANGED``.
303
308
See ``ImageReadMode`` class for more information on various
304
309
available modes. Only applies to JPEG and PNG images.
@@ -312,6 +317,8 @@ def decode_image(
312
317
_log_api_usage_once (decode_image )
313
318
if not isinstance (input , torch .Tensor ):
314
319
input = read_file (str (input ))
320
+ if isinstance (mode , str ):
321
+ mode = ImageReadMode [mode .upper ()]
315
322
output = torch .ops .image .decode_image (input , mode .value , apply_exif_orientation )
316
323
return output
317
324
@@ -360,7 +367,7 @@ def decode_webp(
360
367
Args:
361
368
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
362
369
the raw bytes of the WEBP image.
363
- mode (ImageReadMode): The read mode used for optionally
370
+ mode (str or ImageReadMode): The read mode used for optionally
364
371
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
365
372
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
366
373
@@ -369,6 +376,8 @@ def decode_webp(
369
376
"""
370
377
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
371
378
_log_api_usage_once (decode_webp )
379
+ if isinstance (mode , str ):
380
+ mode = ImageReadMode [mode .upper ()]
372
381
return torch .ops .image .decode_webp (input , mode .value )
373
382
374
383
@@ -389,7 +398,7 @@ def _decode_avif(
389
398
Args:
390
399
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
391
400
the raw bytes of the AVIF image.
392
- mode (ImageReadMode): The read mode used for optionally
401
+ mode (str or ImageReadMode): The read mode used for optionally
393
402
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
394
403
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
395
404
@@ -398,6 +407,8 @@ def _decode_avif(
398
407
"""
399
408
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
400
409
_log_api_usage_once (_decode_avif )
410
+ if isinstance (mode , str ):
411
+ mode = ImageReadMode [mode .upper ()]
401
412
return torch .ops .image .decode_avif (input , mode .value )
402
413
403
414
@@ -415,7 +426,7 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
415
426
Args:
416
427
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
417
428
the raw bytes of the HEIC image.
418
- mode (ImageReadMode): The read mode used for optionally
429
+ mode (str or ImageReadMode): The read mode used for optionally
419
430
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
420
431
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
421
432
@@ -424,4 +435,6 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
424
435
"""
425
436
if not torch .jit .is_scripting () and not torch .jit .is_tracing ():
426
437
_log_api_usage_once (_decode_heic )
438
+ if isinstance (mode , str ):
439
+ mode = ImageReadMode [mode .upper ()]
427
440
return torch .ops .image .decode_heic (input , mode .value )
0 commit comments