@@ -241,6 +241,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: int | None = None
     max_cpu_loras: int | None = None
+    swap_gate_up_proj_lora_b_weight: bool = True
 
     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (
@@ -258,6 +259,7 @@ class LoraModelConfig:
     trtllm_modules_to_hf_modules: dict[str, str]
     hidden_size: int
     dtype: str
+    swap_gate_up_proj_lora_b_weight: bool = True
 
 
 class HfLoraLoader:
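For reference, the new field defaults to True in both dataclasses, so existing checkpoints keep the current swapping behaviour and a caller only touches the config to opt out. A minimal sketch of opting out, assuming LoraConfig is importable from tensorrt_llm.lora_manager and that its remaining fields (e.g. lora_ckpt_source) keep their defaults:

```python
from tensorrt_llm.lora_manager import LoraConfig  # assumed import path

# Opt out of the gate_up_proj lora_B half-swap; all other fields are left
# at their defaults.
lora_config = LoraConfig(swap_gate_up_proj_lora_b_weight=False)
```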
@@ -1026,16 +1028,17 @@ def load_from_hf(
     )
     hf_modules = set(hf_modules_to_trtllm_modules.keys())
 
-    def preprocess_lora_weights(lora_model):
+    def preprocess_lora_weights(lora_model, model_config):
         # Swap weights of gate_up_proj
-        for key, value in lora_model.items():
-            if "gate_up_proj.lora_B.weight" in key:
-                original_weights = value.contiguous().clone()
-                half_split = original_weights.shape[0] // 2
-                first_half = original_weights[:half_split, :]
-                second_half = original_weights[half_split:, :]
-                value = torch.cat((second_half, first_half), dim=0)
-                lora_model[key] = value
+        if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
+            for key, value in lora_model.items():
+                if "gate_up_proj.lora_B.weight" in key:
+                    original_weights = value.contiguous().clone()
+                    half_split = original_weights.shape[0] // 2
+                    first_half = original_weights[:half_split, :]
+                    second_half = original_weights[half_split:, :]
+                    value = torch.cat((second_half, first_half), dim=0)
+                    lora_model[key] = value
         return lora_model
 
     def load_from_model_dir(uid, model_dir, hf_config):
@@ -1047,7 +1050,7 @@ def load_from_model_dir(uid, model_dir, hf_config):
         lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
         if lora_model is None:
             raise ValueError(f"Failed to load adapter_model from {model_dir}")
-        lora_model = preprocess_lora_weights(lora_model)
+        lora_model = preprocess_lora_weights(lora_model, model_config)
         all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
         rank = int(hf_config["r"])
         rs_lora = bool(hf_config.get("use_rslora", False))
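The gated block above reorders the fused gate_up_proj lora_B weight by swapping its two row-halves, presumably converting between the HF [gate; up] row ordering and the fused layout the runtime expects; when the flag is False, checkpoints already in the expected layout pass through untouched. A standalone sketch of the same transformation (hypothetical helper name, not part of the diff):

```python
import torch

def swap_gate_up_halves(weight: torch.Tensor) -> torch.Tensor:
    # Swap the two row-halves of a fused gate_up_proj lora_B weight,
    # mirroring the loop in preprocess_lora_weights above.
    half = weight.shape[0] // 2
    return torch.cat((weight[half:, :], weight[:half, :]), dim=0)

# Toy check: a rank-4 lora_B weight for a fused projection with 8 output rows.
w = torch.arange(32, dtype=torch.float32).reshape(8, 4)
swapped = swap_gate_up_halves(w)
assert torch.equal(swapped[:4], w[4:]) and torch.equal(swapped[4:], w[:4])
# The swap is its own inverse, so applying it twice restores the original layout.
assert torch.equal(swap_gate_up_halves(swapped), w)
```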