
Commit cbb4eb5

Move activation handling to GatedMLP for LoRA compatibility
- Modified the _apply_activation method to accept a for_lora flag, so the FP8 SwiGLU path can keep activations in bf16/fp16 when LoRA is active.
- Updated the _apply_activation call in GatedMLP.forward_lora to pass for_lora=True, ensuring correct behavior on the LoRA path.
- Removed the now-unnecessary dtype casting guard in LoraLayer, simplifying the code.

Signed-off-by: Venky Ganesh <[email protected]>
1 parent 2a0ce00 commit cbb4eb5
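
Background for the change: the LoRA grouped GEMM consumes FP16/BF16 activations, while the FP8 SwiGLU path would otherwise hand it float8 tensors. Below is a minimal sketch of the dtype-selection pattern the diff relies on; the helper name is hypothetical and not part of the commit.

import torch


def lora_activation_dtype() -> torch.dtype:
    # Hypothetical helper mirroring the pattern in the diff below: prefer
    # bfloat16 when the GPU supports it, otherwise fall back to float16,
    # so LoRA kernels never see FP8 activations.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


# Example: make an arbitrary activation tensor LoRA-friendly.
x = torch.randn(4, 8, dtype=torch.float32)
x = x.to(lora_activation_dtype()).contiguous()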

File tree

2 files changed: +17 -16 lines changed

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 17 additions & 5 deletions
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
@@ -95,12 +96,23 @@ def __init__(self,
                 [LoraModuleType.MLP_GATE_UP],
                 [2 * self.intermediate_size // mapping.tp_size])
 
-    def _apply_activation(self, x):
+    def _apply_activation(self, x, *, for_lora: bool = False):
         if self.activation == F.silu:
             if self.down_proj.has_fp8_qdq:
-                return swiglu(x,
-                              quant_scale=self.down_proj.input_scale,
-                              quant_type=torch.float8_e4m3fn)
+                if for_lora:
+
+                    target = torch.bfloat16 if torch.cuda.is_bf16_supported(
+                    ) else torch.float16
+                    logger.debug(
+                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype {target} (keeping activations in bf16/fp16), layer_idx={self.layer_idx}"
+                    )
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=target)
+                else:
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=torch.float8_e4m3fn)
             else:
                 return swiglu(x)
         elif callable(self.activation):
@@ -152,7 +164,7 @@ def forward_lora(
         if h1_lora is not None:
             h1 = h1 + h1_lora
 
-        h2 = self._apply_activation(h1)
+        h2 = self._apply_activation(h1, for_lora=True)
         output = self.down_proj(h2,
                                 all_reduce_params=final_all_reduce_params,
                                 lora_params=lora_params,
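
Taken together, the GatedMLP changes establish a simple contract: the LoRA path keeps activations in bf16/fp16, while the regular path still quantizes to FP8. The stand-alone sketch below illustrates that contract, assuming SiLU gating over concatenated gate/up halves; the real module uses a fused swiglu kernel with per-layer scales, and the ordering of the halves may differ.

import torch
import torch.nn.functional as F


def apply_activation_sketch(x: torch.Tensor, *, for_lora: bool = False) -> torch.Tensor:
    # Simplified stand-in for GatedMLP._apply_activation when FP8 quantization
    # is enabled: split the fused gate/up output and apply SiLU gating.
    gate, up = x.chunk(2, dim=-1)
    h = F.silu(gate) * up
    if for_lora:
        # LoRA path: keep 16-bit activations for the grouped GEMM.
        return h
    # Non-LoRA path: quantize to FP8 as before (scale handling omitted).
    return h.to(torch.float8_e4m3fn)


x = torch.randn(2, 16, dtype=torch.bfloat16)
print(apply_activation_sketch(x, for_lora=True).dtype)  # torch.bfloat16
print(apply_activation_sketch(x).dtype)                 # torch.float8_e4m3fn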

tensorrt_llm/_torch/peft/lora/layer.py

Lines changed: 0 additions & 11 deletions
@@ -3,8 +3,6 @@
 
 import torch
 
-from tensorrt_llm._utils import logger
-
 
 class LoraModuleType(IntEnum):
     """Enum class representing different types of modules that can have LoRA adapters.
@@ -121,15 +119,6 @@ def forward(
         if len(active_lora_module_ids) == 0:
             return None
         else:
-            # Guard: LoRA custom op only supports FP16/BF16 activations.
-            # If upstream produced FP8 (e.g., FP8 SwiGLU), cast here to avoid runtime failure.
-            if x.dtype not in (torch.float16, torch.bfloat16):
-                target_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported(
-                ) else torch.float16
-                logger.debug(
-                    f"lora_grouped_gemm supports only FP16/BF16. Casting input from {x.dtype} to {target_dtype}."
-                )
-                x = x.to(target_dtype).contiguous()
             lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
                 x,
                 lora_params['host_request_types'][:num_seqs],
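
For reference, the guard deleted here can be restated as a small, purely illustrative helper; after this commit it is intentionally not part of LoraLayer, because GatedMLP already keeps LoRA-path activations in bf16/fp16.

import torch


def ensure_lora_dtype(x: torch.Tensor) -> torch.Tensor:
    # Illustrative re-statement of the removed guard: the LoRA grouped GEMM
    # only accepts FP16/BF16, so any other dtype (e.g. FP8) would need a cast.
    # With this commit the cast should never be needed on the LoRA path.
    if x.dtype not in (torch.float16, torch.bfloat16):
        target = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        x = x.to(target).contiguous()
    return x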
