@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
@@ -98,12 +99,21 @@ def __init__(self,
             [LoraModuleType.MLP_GATE_UP],
             [2 * self.intermediate_size // mapping.tp_size])
 
-    def _apply_activation(self, x):
+    def _apply_activation(self, x, *, has_lora: bool = False):
         if self.activation == F.silu:
-            if self.down_proj.has_fp8_qdq or self.down_proj.has_w4a8_nvfp4_fp8:
-                return swiglu(x,
-                              quant_scale=self.down_proj.input_scale,
-                              quant_type=torch.float8_e4m3fn)
+            if self.down_proj.has_fp8_qdq:
+                if has_lora:
+                    # NOTE: This is a WAR, since LoRA grouped_gemm does not support FP8 yet.
+                    # TODO: Remove this path when LoRA grouped_gemm supports FP8
+                    # see: cpp/tensorrt_llm/thop/loraOp.cpp::lora_grouped_gemm
+                    logger.warning(
+                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype bf16/fp16, layer_idx={self.layer_idx}"
+                    )
+                    return swiglu(x)
+                else:
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=torch.float8_e4m3fn)
             else:
                 return swiglu(x)
         elif callable(self.activation):
@@ -155,7 +165,7 @@ def forward_lora(
             if h1_lora is not None:
                 h1 = h1 + h1_lora
 
-        h2 = self._apply_activation(h1)
+        h2 = self._apply_activation(h1, has_lora=True)
         output = self.down_proj(h2,
                                 all_reduce_params=final_all_reduce_params,
                                 lora_params=lora_params,
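For readers skimming the diff, the sketch below restates the decision the new has_lora flag introduces, outside the module. It is a simplified illustration only, under the assumptions stated here: reference_swiglu, apply_activation, down_proj_is_fp8, and down_proj_input_scale are hypothetical stand-ins invented for this sketch, not the TensorRT-LLM API; the real code calls the fused swiglu kernel on the GatedMLP's own attributes.

from typing import Optional

import torch
import torch.nn.functional as F


def reference_swiglu(x: torch.Tensor,
                     quant_scale: Optional[torch.Tensor] = None,
                     quant_type: Optional[torch.dtype] = None) -> torch.Tensor:
    # Reference SwiGLU: split the fused gate/up tensor and gate with SiLU.
    gate, up = x.chunk(2, dim=-1)
    out = F.silu(gate) * up
    if quant_type is torch.float8_e4m3fn and quant_scale is not None:
        # Quantize the activation so an FP8 down projection can consume it directly.
        out = (out / quant_scale).to(torch.float8_e4m3fn)
    return out


def apply_activation(x: torch.Tensor,
                     down_proj_is_fp8: bool,
                     down_proj_input_scale: Optional[torch.Tensor],
                     has_lora: bool) -> torch.Tensor:
    if down_proj_is_fp8 and not has_lora:
        # Fast path: hand the FP8 down projection an already-quantized activation.
        return reference_swiglu(x,
                                quant_scale=down_proj_input_scale,
                                quant_type=torch.float8_e4m3fn)
    # LoRA path (or non-FP8 down projection): keep the activation in bf16/fp16,
    # since the LoRA grouped GEMM does not accept FP8 inputs yet (the WAR above).
    return reference_swiglu(x)

Under these assumptions, the LoRA forward path would call apply_activation(..., has_lora=True), mirroring the forward_lora change in the second hunk, while the plain forward path keeps has_lora at its default of False and retains the FP8-fused activation.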
|