Commit c4791f9

[TRTLLM-6825][fix] Update lora for phi4-mm
Signed-off-by: Wanli Jiang <[email protected]>
Parent: 1e5a6be

7 files changed (+34, -26 lines)


tensorrt_llm/_torch/models/modeling_phi4mm.py

Lines changed: 4 additions & 6 deletions
@@ -243,23 +243,21 @@ def forward(
     @staticmethod
     def lora_config(model_dir: str):
         _lora_config = LoraConfig(
-            lora_dir=[
-                f"{model_dir}/vision-lora",
-                f"{model_dir}/speech-lora",
-            ],
             lora_target_modules=[
                 "attn_qkv",
                 "attn_dense",
-                "mlp_h_to_4h",
+                "mlp_gate_up",
                 "mlp_4h_to_h",
             ],
             trtllm_modules_to_hf_modules={
                 "attn_qkv": "qkv_proj",
                 "attn_dense": "o_proj",
-                "mlp_h_to_4h": "gate_up_proj",
+                "mlp_gate_up": "gate_up_proj",
                 "mlp_4h_to_h": "down_proj",
             },
             max_lora_rank=320,  # Max rank for Phi4MM.
+            swap_gate_up_proj_lora_b_weight=
+            False,  # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
         )
         return _lora_config
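For readability, here is the method's final state reconstructed from the + side of the hunk above, written as a module-level function so the sketch is self-contained (the enclosing class is not shown in this diff). Two things change besides the new flag: the lora_dir entries pointing at the vision-lora/speech-lora subdirectories are gone, leaving model_dir unused here, and the fused gate/up MLP is addressed as "mlp_gate_up" instead of "mlp_h_to_4h".

from tensorrt_llm.lora_manager import LoraConfig

def phi4mm_lora_config(model_dir: str) -> LoraConfig:
    # model_dir is kept for signature compatibility; after lora_dir's
    # removal it is no longer consumed here, so adapter directories must
    # now be supplied by the caller.
    return LoraConfig(
        lora_target_modules=["attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"],
        trtllm_modules_to_hf_modules={
            "attn_qkv": "qkv_proj",
            "attn_dense": "o_proj",
            "mlp_gate_up": "gate_up_proj",
            "mlp_4h_to_h": "down_proj",
        },
        max_lora_rank=320,  # Max rank for Phi4MM.
        # Skip the load-time half-swap of gate_up_proj.lora_B weights.
        swap_gate_up_proj_lora_b_weight=False,
    )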

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 1 deletion
@@ -509,7 +509,8 @@ def create_py_executor_instance(
         resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
         model_engine.set_lora_model_config(
             lora_config.lora_target_modules,
-            lora_config.trtllm_modules_to_hf_modules)
+            lora_config.trtllm_modules_to_hf_modules,
+            lora_config.swap_gate_up_proj_lora_b_weight)
 
     max_num_sequences = executor_config.max_batch_size * mapping.pp_size

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 6 additions & 3 deletions
@@ -437,13 +437,16 @@ def __init__(
         else:
             self.cache_indirection_attention = None
 
-    def set_lora_model_config(self, lora_target_modules: list[str],
-                              trtllm_modules_to_hf_modules: dict[str, str]):
+    def set_lora_model_config(self,
+                              lora_target_modules: list[str],
+                              trtllm_modules_to_hf_modules: dict[str, str],
+                              swap_gate_up_proj_lora_b_weight: bool = True):
         self.lora_model_config = LoraModelConfig(
             lora_target_modules=lora_target_modules,
             trtllm_modules_to_hf_modules=trtllm_modules_to_hf_modules,
             hidden_size=self.model.config.hidden_size,
-            dtype=torch_dtype_to_str(self.model.config.torch_dtype))
+            dtype=torch_dtype_to_str(self.model.config.torch_dtype),
+            swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)
 
     @property
     def use_mrope(self):
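Both the new keyword parameter here and the matching dataclass field in lora_manager.py (below) default to True, so every existing caller keeps the previous swap-enabled behavior; and because the field is appended last, the positional construction in resource_manager.py stays valid. A runnable reduction of that contract (hidden_size and dtype values are placeholders, not taken from the diff):

from dataclasses import dataclass

# Reduced stand-in for tensorrt_llm.lora_manager.LoraModelConfig,
# keeping only the fields this commit touches.
@dataclass
class LoraModelConfig:
    lora_target_modules: list
    trtllm_modules_to_hf_modules: dict
    hidden_size: int
    dtype: str
    swap_gate_up_proj_lora_b_weight: bool = True  # new field, appended last

def set_lora_model_config(lora_target_modules,
                          trtllm_modules_to_hf_modules,
                          swap_gate_up_proj_lora_b_weight=True):
    # Mirrors the updated method body; hidden_size/dtype stand in for
    # self.model.config.hidden_size and torch_dtype_to_str(...).
    return LoraModelConfig(lora_target_modules, trtllm_modules_to_hf_modules,
                           3072, "bfloat16", swap_gate_up_proj_lora_b_weight)

# Two-argument callers keep the old, swap-enabled behavior:
assert set_lora_model_config([], {}).swap_gate_up_proj_lora_b_weight is True
# Phi4MM opts out via create_py_executor_instance:
assert set_lora_model_config([], {}, False).swap_gate_up_proj_lora_b_weight is False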

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 2 additions & 1 deletion
@@ -1206,7 +1206,8 @@ def __init__(self,
         self._lora_model_config = LoraModelConfig(
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
-            binding_to_str_dtype(model_config.data_type))
+            binding_to_str_dtype(model_config.data_type),
+            lora_config.swap_gate_up_proj_lora_b_weight)
         self._lora_manager = LoraManager()
 
     def add_request_peft(self, request: LlmRequest):

tensorrt_llm/lora_manager.py

Lines changed: 13 additions & 10 deletions
@@ -241,6 +241,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: int | None = None
     max_cpu_loras: int | None = None
+    swap_gate_up_proj_lora_b_weight: bool = True
 
     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (
@@ -258,6 +259,7 @@ class LoraModelConfig:
     trtllm_modules_to_hf_modules: dict[str, str]
     hidden_size: int
    dtype: str
+    swap_gate_up_proj_lora_b_weight: bool = True
 
 
 class HfLoraLoader:
@@ -1026,16 +1028,17 @@ def load_from_hf(
     )
     hf_modules = set(hf_modules_to_trtllm_modules.keys())
 
-    def preprocess_lora_weights(lora_model):
+    def preprocess_lora_weights(lora_model, model_config):
         # Swap weights of gate_up_proj
-        for key, value in lora_model.items():
-            if "gate_up_proj.lora_B.weight" in key:
-                original_weights = value.contiguous().clone()
-                half_split = original_weights.shape[0] // 2
-                first_half = original_weights[:half_split, :]
-                second_half = original_weights[half_split:, :]
-                value = torch.cat((second_half, first_half), dim=0)
-                lora_model[key] = value
+        if getattr(model_config, "swap_gate_up_proj_lora_b_weight", True):
+            for key, value in lora_model.items():
+                if "gate_up_proj.lora_B.weight" in key:
+                    original_weights = value.contiguous().clone()
+                    half_split = original_weights.shape[0] // 2
+                    first_half = original_weights[:half_split, :]
+                    second_half = original_weights[half_split:, :]
+                    value = torch.cat((second_half, first_half), dim=0)
+                    lora_model[key] = value
         return lora_model
 
     def load_from_model_dir(uid, model_dir, hf_config):
@@ -1047,7 +1050,7 @@ def load_from_model_dir(uid, model_dir, hf_config):
         lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
         if lora_model is None:
             raise ValueError(f"Failed to load adapter_model from {model_dir}")
-        lora_model = preprocess_lora_weights(lora_model)
+        lora_model = preprocess_lora_weights(lora_model, model_config)
         all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
         rank = int(hf_config["r"])
         rs_lora = bool(hf_config.get("use_rslora", False))
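The guarded block above swaps the two halves of each fused gate_up_proj lora_B matrix along its output dimension, converting between [gate; up] and [up; gate] row orderings; with the flag off, Phi4MM checkpoints are loaded as-is. A minimal, self-contained demonstration of the transform (shapes are illustrative only):

import torch

def swap_gate_up_halves(lora_b: torch.Tensor) -> torch.Tensor:
    # Same transform as preprocess_lora_weights: split the output
    # dimension in half and concatenate the halves in reverse order.
    half_split = lora_b.shape[0] // 2
    return torch.cat((lora_b[half_split:, :], lora_b[:half_split, :]), dim=0)

rank, hidden = 8, 32  # illustrative sizes; Phi4MM allows ranks up to 320
lora_b = torch.randn(2 * hidden, rank)  # fused gate/up stack, shape (2*hidden, rank)

swapped = swap_gate_up_halves(lora_b)
assert torch.equal(swapped[:hidden], lora_b[hidden:])  # halves exchanged
# The swap is an involution, so disabling it (as Phi4MM now does) is safe
# whenever the checkpoint already uses the ordering TRT-LLM expects.
assert torch.equal(swap_gate_up_halves(swapped), lora_b)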

tests/integration/defs/perf/pytorch_model_config.py

Lines changed: 4 additions & 2 deletions
@@ -198,15 +198,17 @@ def get_model_yaml_config(model_label: str,
         }
     if 'phi_4_multimodal_instruct' in model_label:
         lora_config['lora_config']['lora_target_modules'] = [
-            "attn_qkv", "attn_dense", "mlp_h_to_4h", "mlp_4h_to_h"
+            "attn_qkv", "attn_dense", "mlp_gate_up", "mlp_4h_to_h"
         ]
         lora_config['lora_config']['trtllm_modules_to_hf_modules'] = {
             "attn_qkv": "qkv_proj",
             "attn_dense": "o_proj",
-            "mlp_h_to_4h": "gate_up_proj",
+            "mlp_gate_up": "gate_up_proj",
             "mlp_4h_to_h": "down_proj"
         }
         lora_config['lora_config']['max_lora_rank'] = 320
+        lora_config['lora_config'][
+            'swap_gate_up_proj_lora_b_weight'] = False
     base_config.update(lora_config)
 
     kv_cache_config = base_config.get('kv_cache_config', {})

tests/integration/defs/test_e2e.py

Lines changed: 3 additions & 3 deletions
@@ -2486,15 +2486,15 @@ def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
     }
     expected_keywords = {
         "image": [
-            ["image", "depicts", "mountain", "half", "rock"],
-            ["road", "car", "lane", "traffic", "bus"],
+            ["object", "mountain", "weather", "clear", "clouds"],
+            ["traffic", "road", "vehicles", "cars", "bus"],
         ],
         "audio": [
             ["what", "is", "the", "traffic", "sign", "in", "image"],
             ["what", "is", "shown", "in", "this", "image"],
         ],
         "image_audio": [
-            ["image", "depicts", "Grand", "rock", "scene"],
+            ["image", "depicts", "scenic", "famous", "landmark"],
         ],
     }
