 import torch.nn.functional as F
 from torch import nn
 
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
@@ -95,12 +96,23 @@ def __init__(self,
             [LoraModuleType.MLP_GATE_UP],
             [2 * self.intermediate_size // mapping.tp_size])
 
-    def _apply_activation(self, x):
+    def _apply_activation(self, x, *, for_lora: bool = False):
         if self.activation == F.silu:
             if self.down_proj.has_fp8_qdq:
-                return swiglu(x,
-                              quant_scale=self.down_proj.input_scale,
-                              quant_type=torch.float8_e4m3fn)
+                if for_lora:
+                    # LoRA path: keep the activation in bf16/fp16 instead of FP8.
+                    target = torch.bfloat16 if torch.cuda.is_bf16_supported(
+                    ) else torch.float16
+                    logger.debug(
+                        f"GatedMLP._apply_activation: LoRA path active; forcing non-FP8 activation dtype {target}, layer_idx={self.layer_idx}"
+                    )
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=target)
+                else:
+                    return swiglu(x,
+                                  quant_scale=self.down_proj.input_scale,
+                                  quant_type=torch.float8_e4m3fn)
             else:
                 return swiglu(x)
         elif callable(self.activation):
@@ -152,7 +164,7 @@ def forward_lora(
         if h1_lora is not None:
             h1 = h1 + h1_lora
 
-        h2 = self._apply_activation(h1)
+        h2 = self._apply_activation(h1, for_lora=True)
         output = self.down_proj(h2,
                                 all_reduce_params=final_all_reduce_params,
                                 lora_params=lora_params,
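For context, the sketch below reproduces the dtype selection this patch introduces in plain PyTorch: on the LoRA path the SwiGLU output stays in bf16/fp16, while the non-LoRA FP8 path still quantizes to `float8_e4m3fn`. `simple_swiglu` is a hypothetical stand-in for the fused `swiglu` op used above, not the TensorRT-LLM kernel, and the scale handling on the LoRA branch is simplified.

```python
# Minimal sketch, not the TensorRT-LLM implementation: a plain-PyTorch stand-in for
# the fused swiglu op, showing how quant_type decides the activation dtype.
from typing import Optional

import torch
import torch.nn.functional as F


def simple_swiglu(x: torch.Tensor,
                  quant_scale: Optional[torch.Tensor] = None,
                  quant_type: Optional[torch.dtype] = None) -> torch.Tensor:
    # x concatenates the gate and up projections along the last dimension.
    gate, up = x.chunk(2, dim=-1)
    out = F.silu(gate) * up
    if quant_type == torch.float8_e4m3fn:
        # Non-LoRA FP8 path: scale and narrow to FP8 for the FP8 down-projection GEMM.
        out = (out / quant_scale).to(torch.float8_e4m3fn)
    elif quant_type is not None:
        # LoRA path from this patch: keep the activation in bf16/fp16 (scale ignored
        # here for simplicity).
        out = out.to(quant_type)
    return out


x = torch.randn(2, 8, dtype=torch.bfloat16)
scale = torch.tensor(1.0)

# Dtype the patch picks on the LoRA path (guarded so the sketch also runs on CPU).
lora_dtype = (torch.bfloat16 if torch.cuda.is_available()
              and torch.cuda.is_bf16_supported() else torch.float16)

fp8_out = simple_swiglu(x, quant_scale=scale, quant_type=torch.float8_e4m3fn)
lora_out = simple_swiglu(x, quant_scale=scale, quant_type=lora_dtype)
print(fp8_out.dtype, lora_out.dtype)  # float8_e4m3fn vs. bfloat16/float16
```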