
Commit 5a253c8

venkywonka authored and dominicshanshan committed
[https://nvbugs/5464088] [fix] dequantize fp8 activation input to lora forward; update perf test config (NVIDIA#7014)
Signed-off-by: Venky Ganesh <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent f26bad8 commit 5a253c8

File tree

2 files changed (+26, -7 lines)


tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 16 additions & 6 deletions
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
@@ -98,12 +99,21 @@ def __init__(self,
             [LoraModuleType.MLP_GATE_UP],
             [2 * self.intermediate_size // mapping.tp_size])
 
-    def _apply_activation(self, x):
+    def _apply_activation(self, x, *, has_lora: bool = False):
         if self.activation == F.silu:
-            if self.down_proj.has_fp8_qdq or self.down_proj.has_w4a8_nvfp4_fp8:
-                return swiglu(x,
-                              quant_scale=self.down_proj.input_scale,
-                              quant_type=torch.float8_e4m3fn)
+            if self.down_proj.has_fp8_qdq:
+                if has_lora:
+                    # NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
+                    # TODO: Remove this path when LoRA grouped_gemm supports FP8
+                    # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
+                    logger.warning(
+                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
+                    )
+                    return swiglu(x)
+                else:
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=torch.float8_e4m3fn)
             else:
                 return swiglu(x)
         elif callable(self.activation):
@@ -155,7 +165,7 @@ def forward_lora(
         if h1_lora is not None:
             h1 = h1 + h1_lora
 
-        h2 = self._apply_activation(h1)
+        h2 = self._apply_activation(h1, has_lora=True)
         output = self.down_proj(h2,
                                 all_reduce_params=final_all_reduce_params,
                                 lora_params=lora_params,
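
The dispatch this hunk introduces can be read as a small standalone sketch (Python). The swiglu stub and the free-standing apply_activation signature below are illustrative simplifications, not the module's actual API; only the branching mirrors the diff: with LoRA active, the activation output stays in bf16/fp16 instead of being quantized to FP8, since lora_grouped_gemm cannot consume FP8 inputs yet.

import torch
import torch.nn.functional as F


def swiglu(x, quant_scale=None, quant_type=None):
    # Stub of the fused SwiGLU kernel: gate/up halves split on the last dim.
    gate, up = x.chunk(2, dim=-1)
    out = F.silu(gate) * up
    if quant_type is torch.float8_e4m3fn:
        # The real kernel fuses FP8 quantization; this stub just scales and casts.
        out = (out / quant_scale).to(quant_type)
    return out


def apply_activation(x, *, has_fp8_qdq: bool, input_scale, has_lora: bool = False):
    """Simplified stand-in for GatedMLP._apply_activation after this change."""
    if has_fp8_qdq and not has_lora:
        # Plain forward: hand down_proj an already-quantized FP8 tensor.
        return swiglu(x, quant_scale=input_scale, quant_type=torch.float8_e4m3fn)
    # LoRA forward (or unquantized model): keep bf16/fp16 so the LoRA
    # grouped GEMM receives a dtype it supports; down_proj quantizes later.
    return swiglu(x)


x = torch.randn(4, 256, dtype=torch.bfloat16)
y_lora = apply_activation(x, has_fp8_qdq=True, input_scale=1.0, has_lora=True)
assert y_lora.dtype == torch.bfloat16  # LoRA path skips FP8 quantization

y_fp8 = apply_activation(x, has_fp8_qdq=True, input_scale=1.0, has_lora=False)
assert y_fp8.dtype == torch.float8_e4m3fn  # non-LoRA path still emits FP8
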

tests/integration/defs/perf/pytorch_model_config.py

Lines changed: 10 additions & 1 deletion
@@ -181,10 +181,19 @@ def get_model_yaml_config(model_label: str,
 
     # lora-specific change for pytorch
     if 'pytorch' in model_label and 'loras' in model_label:
+        # Derive the requested number of adapters from model_label (segment like "loras:X")
+        lora_count = 1
+        for part in model_label.split('-'):
+            if part.startswith('loras:'):
+                lora_count = max(1, int(part.split(':', 1)[1]))
+                break
+
         lora_config = {
             'lora_config': {
                 'lora_dir': lora_dirs if lora_dirs is not None else [],
-                'max_lora_rank': 64
+                'max_lora_rank': 64,
+                'max_loras': lora_count,
+                'max_cpu_loras': lora_count,
             }
         }
     if 'phi_4_multimodal_instruct' in model_label:
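
The perf-config change sizes the LoRA caches from the "loras:X" segment of the model label instead of relying on defaults. A standalone sketch of that parsing follows; the helper name lora_count_from_label and the example label are hypothetical, used only to illustrate the behavior of the inlined loop above.

def lora_count_from_label(model_label: str) -> int:
    """Extract the adapter count from a hyphen-separated perf-test label.

    The LoRA segment looks like "loras:X"; fall back to 1 when absent.
    """
    for part in model_label.split('-'):
        if part.startswith('loras:'):
            return max(1, int(part.split(':', 1)[1]))
    return 1


# Hypothetical label: both cache limits follow the requested adapter count.
label = "llama_v3.1_8b-bench-pytorch-bfloat16-loras:8-maxbs:1"
count = lora_count_from_label(label)
lora_config = {
    'lora_config': {
        'lora_dir': [],
        'max_lora_rank': 64,
        'max_loras': count,      # adapters kept in the GPU-side cache
        'max_cpu_loras': count,  # adapters kept in the host-side cache
    }
}
assert count == 8
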
