From 85c731070f98e914be8de96456a4559da523e52e Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 10 May 2024 14:18:59 -0400 Subject: [PATCH] Accept RGBA image as input (#2877) --- internal_controlnet/args.py | 9 ++++++--- unit_tests/args_test.py | 7 +++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/internal_controlnet/args.py b/internal_controlnet/args.py index f98b99be7..05dcb6d73 100644 --- a/internal_controlnet/args.py +++ b/internal_controlnet/args.py @@ -16,6 +16,7 @@ HiResFixOption, PuLIDMode, ) +from annotator.util import HWC3 def _unimplemented_func(*args, **kwargs): @@ -279,9 +280,11 @@ def parse_image(cls, image) -> np.ndarray: else: raise ValueError(f"Unrecognized image format {image}.") - # [H, W] => [H, W, 3] - if np_image.ndim == 2: - np_image = np.stack([np_image, np_image, np_image], axis=-1) + # Convert following image shapes to shape [H, W, C=3]. + # - [H, W] + # - [H, W, 1] + # - [H, W, 4] + np_image = HWC3(np_image) assert np_image.ndim == 3 assert np_image.shape[2] == 3 return np_image diff --git a/unit_tests/args_test.py b/unit_tests/args_test.py index 0c053f73f..31d79eeaa 100644 --- a/unit_tests/args_test.py +++ b/unit_tests/args_test.py @@ -10,7 +10,7 @@ img1 = np.ones(shape=[H, W, 3], dtype=np.uint8) img2 = np.ones(shape=[H, W, 3], dtype=np.uint8) * 2 mask_diff = np.ones(shape=[H - 1, W - 1, 3], dtype=np.uint8) * 2 -mask_2d = np.ones(shape=[H, W]) +mask_2d = np.ones(shape=[H, W], dtype=np.uint8) img_bad_channel = np.ones(shape=[H, W, 2], dtype=np.uint8) * 2 img_bad_dim = np.ones(shape=[1, H, W, 3], dtype=np.uint8) * 2 ui_img_diff = np.ones(shape=[H - 1, W - 1, 4], dtype=np.uint8) * 2 @@ -104,9 +104,12 @@ def test_model_valid(set_cls_funcs): # UI dict(image=dict(image=img1)), dict(image=dict(image=img1, mask=img2)), - # 2D mask should be accepted. + # Grey-scale mask/image should be accepted. dict(image=dict(image=img1, mask=mask_2d)), dict(image=img1, mask=mask_2d), + dict(image=np.zeros(shape=[H, W, 1], dtype=np.uint8)), + # RGBA image input should be accepted. + dict(image=np.zeros(shape=[H, W, 4], dtype=np.uint8)), ], ) def test_valid_image_formats(set_cls_funcs, d):