Skip to content

Commit 47d3021

Browse files
affromeroandressayakpaul
authored andcommitted
[bug fix] Inpainting for MultiAdapter (huggingface#5922)
* bug in MultiAdapter for Inpainting * adapter_input is a list for MultiAdapter --------- Co-authored-by: andres <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 70f7bcd commit 47d3021

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,7 +1470,15 @@ def __call__(
14701470
height, width = self._default_height_width(height, width, adapter_image)
14711471
device = self._execution_device
14721472

1473-
adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device)
1473+
if isinstance(adapter, MultiAdapter):
1474+
adapter_input = []
1475+
for one_image in adapter_image:
1476+
one_image = _preprocess_adapter_image(one_image, height, width)
1477+
one_image = one_image.to(device=device, dtype=adapter.dtype)
1478+
adapter_input.append(one_image)
1479+
else:
1480+
adapter_input = _preprocess_adapter_image(adapter_image, height, width)
1481+
adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)
14741482

14751483
original_size = original_size or (height, width)
14761484
target_size = target_size or (height, width)
@@ -1643,10 +1651,14 @@ def denoising_value_valid(dnv):
16431651
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
16441652

16451653
# 10. Prepare added time ids & embeddings & adapter features
1646-
adapter_input = adapter_input.type(latents.dtype)
1647-
adapter_state = adapter(adapter_input)
1648-
for k, v in enumerate(adapter_state):
1649-
adapter_state[k] = v * adapter_conditioning_scale
1654+
if isinstance(adapter, MultiAdapter):
1655+
adapter_state = adapter(adapter_input, adapter_conditioning_scale)
1656+
for k, v in enumerate(adapter_state):
1657+
adapter_state[k] = v
1658+
else:
1659+
adapter_state = adapter(adapter_input)
1660+
for k, v in enumerate(adapter_state):
1661+
adapter_state[k] = v * adapter_conditioning_scale
16501662
if num_images_per_prompt > 1:
16511663
for k, v in enumerate(adapter_state):
16521664
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)

0 commit comments

Comments
 (0)