Skip to content

Commit 0dcd180

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Allow decoding functions to accept the mode parameter as a string (#8627)
Reviewed By: vmoens Differential Revision: D62581688 fbshipit-source-id: 47f2b7b791148279dcb1d01313be11ecabacb5fb
1 parent 71bafd6 commit 0dcd180

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

test/test_image.py

+11
Original file line numberDiff line numberDiff line change
@@ -1065,5 +1065,16 @@ def test_decode_image_path(input_type, scripted):
10651065
decode_fun(input)
10661066

10671067

1068+
def test_mode_str():
1069+
# Make sure decode_image supports string modes. We just test decode_image,
1070+
# not all of the decoding functions, but they should all support that too.
1071+
# Torchscript fails when passing strings, which is expected.
1072+
path = next(get_images(IMAGE_ROOT, ".png"))
1073+
assert decode_image(path, mode="RGB").shape[0] == 3
1074+
assert decode_image(path, mode="rGb").shape[0] == 3
1075+
assert decode_image(path, mode="GRAY").shape[0] == 1
1076+
assert decode_image(path, mode="RGBA").shape[0] == 4
1077+
1078+
10681079
if __name__ == "__main__":
10691080
pytest.main([__file__])

torchvision/io/image.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class ImageReadMode(Enum):
4040
GRAY_ALPHA = 2
4141
RGB = 3
4242
RGB_ALPHA = 4
43+
RGBA = RGB_ALPHA # Alias for convenience
4344

4445

4546
def read_file(path: str) -> torch.Tensor:
@@ -92,7 +93,7 @@ def decode_png(
9293
Args:
9394
input (Tensor[1]): a one dimensional uint8 tensor containing
9495
the raw bytes of the PNG image.
95-
mode (ImageReadMode): the read mode used for optionally
96+
mode (str or ImageReadMode): the read mode used for optionally
9697
converting the image. Default: ``ImageReadMode.UNCHANGED``.
9798
See `ImageReadMode` class for more information on various
9899
available modes.
@@ -104,6 +105,8 @@ def decode_png(
104105
"""
105106
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
106107
_log_api_usage_once(decode_png)
108+
if isinstance(mode, str):
109+
mode = ImageReadMode[mode.upper()]
107110
output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
108111
return output
109112

@@ -168,7 +171,7 @@ def decode_jpeg(
168171
input (Tensor[1] or list[Tensor[1]]): a (list of) one dimensional uint8 tensor(s) containing
169172
the raw bytes of the JPEG image. The tensor(s) must be on CPU,
170173
regardless of the ``device`` parameter.
171-
mode (ImageReadMode): the read mode used for optionally
174+
mode (str or ImageReadMode): the read mode used for optionally
172175
converting the image(s). The supported modes are: ``ImageReadMode.UNCHANGED``,
173176
``ImageReadMode.GRAY`` and ``ImageReadMode.RGB``
174177
Default: ``ImageReadMode.UNCHANGED``.
@@ -198,6 +201,8 @@ def decode_jpeg(
198201
_log_api_usage_once(decode_jpeg)
199202
if isinstance(device, str):
200203
device = torch.device(device)
204+
if isinstance(mode, str):
205+
mode = ImageReadMode[mode.upper()]
201206

202207
if isinstance(input, list):
203208
if len(input) == 0:
@@ -298,7 +303,7 @@ def decode_image(
298303
input (Tensor or str or ``pathlib.Path``): The image to decode. If a
299304
tensor is passed, it must be one dimensional uint8 tensor containing
300305
the raw bytes of the image. Otherwise, this must be a path to the image file.
301-
mode (ImageReadMode): the read mode used for optionally converting the image.
306+
mode (str or ImageReadMode): the read mode used for optionally converting the image.
302307
Default: ``ImageReadMode.UNCHANGED``.
303308
See ``ImageReadMode`` class for more information on various
304309
available modes. Only applies to JPEG and PNG images.
@@ -312,6 +317,8 @@ def decode_image(
312317
_log_api_usage_once(decode_image)
313318
if not isinstance(input, torch.Tensor):
314319
input = read_file(str(input))
320+
if isinstance(mode, str):
321+
mode = ImageReadMode[mode.upper()]
315322
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
316323
return output
317324

@@ -360,7 +367,7 @@ def decode_webp(
360367
Args:
361368
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
362369
the raw bytes of the WEBP image.
363-
mode (ImageReadMode): The read mode used for optionally
370+
mode (str or ImageReadMode): The read mode used for optionally
364371
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
365372
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
366373
@@ -369,6 +376,8 @@ def decode_webp(
369376
"""
370377
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
371378
_log_api_usage_once(decode_webp)
379+
if isinstance(mode, str):
380+
mode = ImageReadMode[mode.upper()]
372381
return torch.ops.image.decode_webp(input, mode.value)
373382

374383

@@ -389,7 +398,7 @@ def _decode_avif(
389398
Args:
390399
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
391400
the raw bytes of the AVIF image.
392-
mode (ImageReadMode): The read mode used for optionally
401+
mode (str or ImageReadMode): The read mode used for optionally
393402
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
394403
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
395404
@@ -398,6 +407,8 @@ def _decode_avif(
398407
"""
399408
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
400409
_log_api_usage_once(_decode_avif)
410+
if isinstance(mode, str):
411+
mode = ImageReadMode[mode.upper()]
401412
return torch.ops.image.decode_avif(input, mode.value)
402413

403414

@@ -415,7 +426,7 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
415426
Args:
416427
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
417428
the raw bytes of the HEIC image.
418-
mode (ImageReadMode): The read mode used for optionally
429+
mode (str or ImageReadMode): The read mode used for optionally
419430
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
420431
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
421432
@@ -424,4 +435,6 @@ def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
424435
"""
425436
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
426437
_log_api_usage_once(_decode_heic)
438+
if isinstance(mode, str):
439+
mode = ImageReadMode[mode.upper()]
427440
return torch.ops.image.decode_heic(input, mode.value)

0 commit comments

Comments
 (0)