
Commit 972b2d6

fix: Fix gibberish output with FP8 checkpoint
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 1f39a11 commit 972b2d6

6 files changed: +50 -10 lines changed


tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 13 additions & 10 deletions
@@ -158,25 +158,28 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
 
 class Gemma3MLP(nn.Module):
 
-    def __init__(self, config: Gemma3TextConfig):
+    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
+        self.config = model_config.pretrained_config
+        self.hidden_size = self.config.hidden_size
+        self.intermediate_size = self.config.intermediate_size
+        self.dtype = self.config.torch_dtype
         self.gate_proj = Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False,
-                                dtype=self.dtype)
+                                dtype=self.dtype,
+                                quant_config=model_config.get_quant_config())
         self.up_proj = Linear(self.hidden_size,
                               self.intermediate_size,
                               bias=False,
-                              dtype=self.dtype)
+                              dtype=self.dtype,
+                              quant_config=model_config.get_quant_config())
         self.down_proj = Linear(self.intermediate_size,
                                 self.hidden_size,
                                 bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
+                                dtype=self.dtype,
+                                quant_config=model_config.get_quant_config())
+        self.act_fn = ACT2FN[self.config.hidden_activation]
 
     @torch.inference_mode()
     def forward(self, x):

@@ -202,7 +205,7 @@ def __init__(
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(config)
+        self.mlp = Gemma3MLP(model_config=model_config)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
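
Note on the fix: Gemma3MLP previously received only the bare Gemma3TextConfig, so its gate/up/down projections were built without a quant_config and an FP8 checkpoint's quantized weights were loaded as if they were ordinary unquantized tensors, which is the likely source of the gibberish output. Below is a minimal sketch of the fixed wiring; the import paths and the ModelConfig/QuantConfig constructor arguments are assumptions for illustration and are not taken from this commit.

# Sketch only (not part of this commit): build Gemma3MLP with an FP8 quant
# config so each Linear receives quant_config from the ModelConfig.
# Import paths and constructor fields below are assumptions.
from transformers import Gemma3TextConfig

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_gemma3 import Gemma3MLP
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

pretrained_config = Gemma3TextConfig.from_pretrained("google/gemma-3-1b-it")
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                           kv_cache_quant_algo=QuantAlgo.FP8)

# After this commit the MLP takes the full ModelConfig, and each Linear is
# constructed with quant_config=model_config.get_quant_config().
model_config = ModelConfig(pretrained_config=pretrained_config,
                           quant_config=quant_config)
mlp = Gemma3MLP(model_config=model_config)  # previously: Gemma3MLP(pretrained_config)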

tests/integration/defs/accuracy/references/cnn_dailymail.yaml

Lines changed: 6 additions & 0 deletions
@@ -1,7 +1,13 @@
 google/gemma-3-1b-it:
   - accuracy: 22.988
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 22.988
 google/gemma-3-27b-it:
   - accuracy: 28.90
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 28.90
 gpt2:
   - accuracy: 18.408
   - quant_algo: W8A16

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 6 additions & 0 deletions
@@ -115,8 +115,14 @@ speakleash/Bielik-11B-v2.2-Instruct:
     accuracy: 40.41
 google/gemma-3-1b-it:
   - accuracy: 25.52 # score getting from lm-eval with HF implementation
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 25.52
 google/gemma-3-27b-it:
   - accuracy: 91.66
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 91.66
 mistralai/Ministral-8B-Instruct-2410:
   - accuracy: 79.25
   - quant_algo: FP8

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 8 additions & 0 deletions
@@ -99,8 +99,16 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
   - accuracy: 81.7
 google/gemma-2-9b-it:
   - accuracy: 73.05
+google/gemma-3-1b-it:
+  - accuracy: 39.0
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 39.0
 google/gemma-3-27b-it:
   - accuracy: 77.80
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 77.80
 Qwen/Qwen2-0.5B-Instruct:
   - accuracy: 45.30
   - quant_algo: FP8

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 16 additions & 0 deletions
@@ -604,6 +604,22 @@ def test_auto_dtype(self):
             task.evaluate(llm)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    def test_fp8_prequantized(self):
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        enable_partial_reuse=False,
+                                        dtype="fp8")
+        prequantized_model_path = "/home/bbuddharaju/scratch/random/hf_models/gemma-3-1b-it-fp8/"
+        with LLM(prequantized_model_path,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
 
     def test_auto_dtype_vswa(self):
         # NOTE: Test with VSWA kv cache config.
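
The new test exercises the FP8 path end to end: it loads a pre-quantized Gemma3 checkpoint, runs the KV cache in FP8, and disables block reuse as a WAR for Gemma3's sliding-window handling. Outside the test harness, roughly the same setup can be driven through the LLM API; the sketch below mirrors the test, where the checkpoint path is a placeholder and the import locations of KvCacheConfig and QuantAlgo are assumptions.

# Standalone sketch (not part of this commit) mirroring test_fp8_prequantized.
# "path/to/gemma-3-1b-it-fp8" is a placeholder; import paths are assumptions.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.quantization import QuantAlgo

# Block reuse is disabled as a WAR for Gemma3's non-inclusive sliding window;
# dtype="fp8" keeps the KV cache in FP8 alongside the FP8 weights.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                enable_partial_reuse=False,
                                dtype="fp8")

with LLM("path/to/gemma-3-1b-it-fp8", kv_cache_config=kv_cache_config) as llm:
    # The quant algo is picked up from the checkpoint's quantization config.
    assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
    results = llm.generate(["The capital of France is"])
    print(results[0].outputs[0].text)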

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
   - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
