@@ -1470,7 +1470,15 @@ def __call__(
14701470 height , width = self ._default_height_width (height , width , adapter_image )
14711471 device = self ._execution_device
14721472
1473- adapter_input = _preprocess_adapter_image (adapter_image , height , width ).to (device )
1473+ if isinstance (adapter , MultiAdapter ):
1474+ adapter_input = []
1475+ for one_image in adapter_image :
1476+ one_image = _preprocess_adapter_image (one_image , height , width )
1477+ one_image = one_image .to (device = device , dtype = adapter .dtype )
1478+ adapter_input .append (one_image )
1479+ else :
1480+ adapter_input = _preprocess_adapter_image (adapter_image , height , width )
1481+ adapter_input = adapter_input .to (device = device , dtype = adapter .dtype )
14741482
14751483 original_size = original_size or (height , width )
14761484 target_size = target_size or (height , width )
@@ -1643,10 +1651,14 @@ def denoising_value_valid(dnv):
16431651 extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
16441652
16451653 # 10. Prepare added time ids & embeddings & adapter features
1646- adapter_input = adapter_input .type (latents .dtype )
1647- adapter_state = adapter (adapter_input )
1648- for k , v in enumerate (adapter_state ):
1649- adapter_state [k ] = v * adapter_conditioning_scale
1654+ if isinstance (adapter , MultiAdapter ):
1655+ adapter_state = adapter (adapter_input , adapter_conditioning_scale )
1656+ for k , v in enumerate (adapter_state ):
1657+ adapter_state [k ] = v
1658+ else :
1659+ adapter_state = adapter (adapter_input )
1660+ for k , v in enumerate (adapter_state ):
1661+ adapter_state [k ] = v * adapter_conditioning_scale
16501662 if num_images_per_prompt > 1 :
16511663 for k , v in enumerate (adapter_state ):
16521664 adapter_state [k ] = v .repeat (num_images_per_prompt , 1 , 1 , 1 )
0 commit comments