Skip to content

Commit

Permalink
Accept RGBA image as input (#2877)
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei authored May 10, 2024
1 parent ae9528b commit 85c7310
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
9 changes: 6 additions & 3 deletions internal_controlnet/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HiResFixOption,
PuLIDMode,
)
from annotator.util import HWC3


def _unimplemented_func(*args, **kwargs):
Expand Down Expand Up @@ -279,9 +280,11 @@ def parse_image(cls, image) -> np.ndarray:
else:
raise ValueError(f"Unrecognized image format {image}.")

# [H, W] => [H, W, 3]
if np_image.ndim == 2:
np_image = np.stack([np_image, np_image, np_image], axis=-1)
# Convert following image shapes to shape [H, W, C=3].
# - [H, W]
# - [H, W, 1]
# - [H, W, 4]
np_image = HWC3(np_image)
assert np_image.ndim == 3
assert np_image.shape[2] == 3
return np_image
Expand Down
7 changes: 5 additions & 2 deletions unit_tests/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
img1 = np.ones(shape=[H, W, 3], dtype=np.uint8)
img2 = np.ones(shape=[H, W, 3], dtype=np.uint8) * 2
mask_diff = np.ones(shape=[H - 1, W - 1, 3], dtype=np.uint8) * 2
mask_2d = np.ones(shape=[H, W])
mask_2d = np.ones(shape=[H, W], dtype=np.uint8)
img_bad_channel = np.ones(shape=[H, W, 2], dtype=np.uint8) * 2
img_bad_dim = np.ones(shape=[1, H, W, 3], dtype=np.uint8) * 2
ui_img_diff = np.ones(shape=[H - 1, W - 1, 4], dtype=np.uint8) * 2
Expand Down Expand Up @@ -104,9 +104,12 @@ def test_model_valid(set_cls_funcs):
# UI
dict(image=dict(image=img1)),
dict(image=dict(image=img1, mask=img2)),
# 2D mask should be accepted.
# Grey-scale mask/image should be accepted.
dict(image=dict(image=img1, mask=mask_2d)),
dict(image=img1, mask=mask_2d),
dict(image=np.zeros(shape=[H, W, 1], dtype=np.uint8)),
# RGBA image input should be accepted.
dict(image=np.zeros(shape=[H, W, 4], dtype=np.uint8)),
],
)
def test_valid_image_formats(set_cls_funcs, d):
Expand Down

0 comments on commit 85c7310

Please sign in to comment.