1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+
1415from pathlib import Path
15- from typing import Dict , Union
16+ from typing import Dict , List , Union
1617
1718import torch
1819from huggingface_hub .utils import validate_hf_hub_args
@@ -45,9 +46,9 @@ class IPAdapterMixin:
4546 @validate_hf_hub_args
4647 def load_ip_adapter (
4748 self ,
48- pretrained_model_name_or_path_or_dict : Union [str , Dict [str , torch .Tensor ]],
49- subfolder : str ,
50- weight_name : str ,
49+ pretrained_model_name_or_path_or_dict : Union [str , List [ str ], Dict [str , torch .Tensor ]],
50+ subfolder : Union [ str , List [ str ]] ,
51+ weight_name : Union [ str , List [ str ]] ,
5152 ** kwargs ,
5253 ):
5354 """
@@ -87,6 +88,26 @@ def load_ip_adapter(
8788 The subfolder location of a model file within a larger model repository on the Hub or locally.
8889 """
8990
91+ # handle the list inputs for multiple IP Adapters
92+ if not isinstance (weight_name , list ):
93+ weight_name = [weight_name ]
94+
95+ if not isinstance (pretrained_model_name_or_path_or_dict , list ):
96+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict ]
97+ if len (pretrained_model_name_or_path_or_dict ) == 1 :
98+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len (weight_name )
99+
100+ if not isinstance (subfolder , list ):
101+ subfolder = [subfolder ]
102+ if len (subfolder ) == 1 :
103+ subfolder = subfolder * len (weight_name )
104+
105+ if len (weight_name ) != len (pretrained_model_name_or_path_or_dict ):
106+ raise ValueError ("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length." )
107+
108+ if len (weight_name ) != len (subfolder ):
109+ raise ValueError ("`weight_name` and `subfolder` must have the same length." )
110+
90111 # Load the main state dict first.
91112 cache_dir = kwargs .pop ("cache_dir" , None )
92113 force_download = kwargs .pop ("force_download" , False )
@@ -100,61 +121,68 @@ def load_ip_adapter(
100121 "file_type" : "attn_procs_weights" ,
101122 "framework" : "pytorch" ,
102123 }
103-
104- if not isinstance (pretrained_model_name_or_path_or_dict , dict ):
105- model_file = _get_model_file (
106- pretrained_model_name_or_path_or_dict ,
107- weights_name = weight_name ,
108- cache_dir = cache_dir ,
109- force_download = force_download ,
110- resume_download = resume_download ,
111- proxies = proxies ,
112- local_files_only = local_files_only ,
113- token = token ,
114- revision = revision ,
115- subfolder = subfolder ,
116- user_agent = user_agent ,
117- )
118- if weight_name .endswith (".safetensors" ):
119- state_dict = {"image_proj" : {}, "ip_adapter" : {}}
120- with safe_open (model_file , framework = "pt" , device = "cpu" ) as f :
121- for key in f .keys ():
122- if key .startswith ("image_proj." ):
123- state_dict ["image_proj" ][key .replace ("image_proj." , "" )] = f .get_tensor (key )
124- elif key .startswith ("ip_adapter." ):
125- state_dict ["ip_adapter" ][key .replace ("ip_adapter." , "" )] = f .get_tensor (key )
126- else :
127- state_dict = torch .load (model_file , map_location = "cpu" )
128- else :
129- state_dict = pretrained_model_name_or_path_or_dict
130-
131- keys = list (state_dict .keys ())
132- if keys != ["image_proj" , "ip_adapter" ]:
133- raise ValueError ("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict." )
134-
135- # load CLIP image encoder here if it has not been registered to the pipeline yet
136- if hasattr (self , "image_encoder" ) and getattr (self , "image_encoder" , None ) is None :
124+ state_dicts = []
125+ for pretrained_model_name_or_path_or_dict , weight_name , subfolder in zip (
126+ pretrained_model_name_or_path_or_dict , weight_name , subfolder
127+ ):
137128 if not isinstance (pretrained_model_name_or_path_or_dict , dict ):
138- logger .info (f"loading image_encoder from { pretrained_model_name_or_path_or_dict } " )
139- image_encoder = CLIPVisionModelWithProjection .from_pretrained (
129+ model_file = _get_model_file (
140130 pretrained_model_name_or_path_or_dict ,
141- subfolder = Path (subfolder , "image_encoder" ).as_posix (),
142- ).to (self .device , dtype = self .dtype )
143- self .image_encoder = image_encoder
144- self .register_to_config (image_encoder = ["transformers" , "CLIPVisionModelWithProjection" ])
131+ weights_name = weight_name ,
132+ cache_dir = cache_dir ,
133+ force_download = force_download ,
134+ resume_download = resume_download ,
135+ proxies = proxies ,
136+ local_files_only = local_files_only ,
137+ token = token ,
138+ revision = revision ,
139+ subfolder = subfolder ,
140+ user_agent = user_agent ,
141+ )
142+ if weight_name .endswith (".safetensors" ):
143+ state_dict = {"image_proj" : {}, "ip_adapter" : {}}
144+ with safe_open (model_file , framework = "pt" , device = "cpu" ) as f :
145+ for key in f .keys ():
146+ if key .startswith ("image_proj." ):
147+ state_dict ["image_proj" ][key .replace ("image_proj." , "" )] = f .get_tensor (key )
148+ elif key .startswith ("ip_adapter." ):
149+ state_dict ["ip_adapter" ][key .replace ("ip_adapter." , "" )] = f .get_tensor (key )
150+ else :
151+ state_dict = torch .load (model_file , map_location = "cpu" )
145152 else :
146- raise ValueError ("`image_encoder` cannot be None when using IP Adapters." )
147-
148- # create feature extractor if it has not been registered to the pipeline yet
149- if hasattr (self , "feature_extractor" ) and getattr (self , "feature_extractor" , None ) is None :
150- self .feature_extractor = CLIPImageProcessor ()
151- self .register_to_config (feature_extractor = ["transformers" , "CLIPImageProcessor" ])
152-
153- # load ip-adapter into unet
153+ state_dict = pretrained_model_name_or_path_or_dict
154+
155+ keys = list (state_dict .keys ())
156+ if keys != ["image_proj" , "ip_adapter" ]:
157+ raise ValueError ("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict." )
158+
159+ state_dicts .append (state_dict )
160+
161+ # load CLIP image encoder here if it has not been registered to the pipeline yet
162+ if hasattr (self , "image_encoder" ) and getattr (self , "image_encoder" , None ) is None :
163+ if not isinstance (pretrained_model_name_or_path_or_dict , dict ):
164+ logger .info (f"loading image_encoder from { pretrained_model_name_or_path_or_dict } " )
165+ image_encoder = CLIPVisionModelWithProjection .from_pretrained (
166+ pretrained_model_name_or_path_or_dict ,
167+ subfolder = Path (subfolder , "image_encoder" ).as_posix (),
168+ ).to (self .device , dtype = self .dtype )
169+ self .image_encoder = image_encoder
170+ self .register_to_config (image_encoder = ["transformers" , "CLIPVisionModelWithProjection" ])
171+ else :
172+ raise ValueError ("`image_encoder` cannot be None when using IP Adapters." )
173+
174+ # create feature extractor if it has not been registered to the pipeline yet
175+ if hasattr (self , "feature_extractor" ) and getattr (self , "feature_extractor" , None ) is None :
176+ feature_extractor = CLIPImageProcessor ()
177+ self .register_modules (feature_extractor = feature_extractor )
178+
179+ # load ip-adapter into unet
154180 unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
155- unet ._load_ip_adapter_weights (state_dict )
181+ unet ._load_ip_adapter_weights (state_dicts )
156182
157183 def set_ip_adapter_scale (self , scale ):
184+ if not isinstance (scale , list ):
185+ scale = [scale ]
158186 unet = getattr (self , self .unet_name ) if not hasattr (self , "unet" ) else self .unet
159187 for attn_processor in unet .attn_processors .values ():
160188 if isinstance (attn_processor , (IPAdapterAttnProcessor , IPAdapterAttnProcessor2_0 )):
0 commit comments