
Commit 0c2d032

brb-nv authored and lancelly committed
fix: Fix poor generation with FP8 Gemma3 1B checkpoint (NVIDIA#6499)
Signed-off-by: Balaram Buddharaju <[email protected]>
Signed-off-by: Lanyu Liao <[email protected]>
1 parent 55cd214 commit 0c2d032

5 files changed: +39 −10 lines changed

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 14 additions & 10 deletions
@@ -158,25 +158,29 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
 
 class Gemma3MLP(nn.Module):
 
-    def __init__(self, config: Gemma3TextConfig):
+    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
+        self.config = model_config.pretrained_config
+        self.hidden_size = self.config.hidden_size
+        self.intermediate_size = self.config.intermediate_size
+        self.dtype = self.config.torch_dtype
+        self.quant_config = model_config.get_quant_config()
         self.gate_proj = Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False,
-                                dtype=self.dtype)
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
         self.up_proj = Linear(self.hidden_size,
                               self.intermediate_size,
                               bias=False,
-                              dtype=self.dtype)
+                              dtype=self.dtype,
+                              quant_config=self.quant_config)
         self.down_proj = Linear(self.intermediate_size,
                                 self.hidden_size,
                                 bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
+        self.act_fn = ACT2FN[self.config.hidden_activation]
 
     @torch.inference_mode()
     def forward(self, x):
@@ -202,7 +206,7 @@ def __init__(
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(config)
+        self.mlp = Gemma3MLP(model_config=model_config)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
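
Why this fixes the FP8 checkpoint: Gemma3MLP previously built its gate/up/down projections without a quant_config, so the FP8 weights and scales shipped in the prequantized checkpoint were loaded into plain high-precision layers and misinterpreted, degrading generation. The sketch below illustrates that idea only; QuantAwareLinear and its dict-style quant_config are hypothetical stand-ins, not the tensorrt_llm Linear or ModelConfig APIs.

# Hypothetical sketch (QuantAwareLinear and the dict-style quant_config are
# illustrative stand-ins, not tensorrt_llm classes). A layer can only allocate
# the FP8 weight + scale that a prequantized checkpoint provides if it is told
# about the quantization scheme, which is why each Gemma3MLP projection now
# receives model_config.get_quant_config().
import torch
from torch import nn


class QuantAwareLinear(nn.Module):

    def __init__(self, in_features, out_features, dtype=torch.bfloat16,
                 quant_config=None):
        super().__init__()
        self.quant_config = quant_config
        if quant_config is not None and quant_config.get("quant_algo") == "FP8":
            # Mirrors what an FP8 checkpoint ships: quantized weight + scale.
            # (float8_e4m3fn requires a reasonably recent PyTorch build.)
            self.register_buffer(
                "weight",
                torch.zeros(out_features, in_features,
                            dtype=torch.float8_e4m3fn))
            self.register_buffer("weight_scale", torch.ones((), dtype=dtype))
        else:
            # Plain high-precision weight; FP8 checkpoint tensors do not match
            # this layout, so loading them here corrupts the projection.
            self.register_buffer(
                "weight", torch.zeros(out_features, in_features, dtype=dtype))

    def forward(self, x):
        w = self.weight
        if w.dtype == torch.float8_e4m3fn:
            w = w.to(x.dtype) * self.weight_scale  # dequantize before the GEMM
        return nn.functional.linear(x, w)


# Before the fix, the MLP projections were effectively the "else" branch above;
# threading quant_config through keeps their buffers in sync with the checkpoint.
proj = QuantAwareLinear(1024, 4096, quant_config={"quant_algo": "FP8"})
print(proj.weight.dtype)  # torch.float8_e4m3fn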

tests/integration/defs/accuracy/references/cnn_dailymail.yaml

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 google/gemma-3-1b-it:
   - accuracy: 22.988
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 22.988
 google/gemma-3-27b-it:
   - accuracy: 28.90
 gpt2:

tests/integration/defs/accuracy/references/mmlu.yaml

Lines changed: 5 additions & 0 deletions
@@ -100,6 +100,11 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
   - accuracy: 81.7
 google/gemma-2-9b-it:
   - accuracy: 73.05
+google/gemma-3-1b-it:
+  - accuracy: 39.0
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 39.0
 google/gemma-3-27b-it:
   - accuracy: 77.80
 Qwen/Qwen2-0.5B-Instruct:

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 16 additions & 0 deletions
@@ -604,6 +604,22 @@ def test_auto_dtype(self):
             task.evaluate(llm)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    def test_fp8_prequantized(self):
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        enable_partial_reuse=False,
+                                        dtype="fp8")
+        prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it-fp8/"
+        with LLM(prequantized_model_path,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
 
     def test_auto_dtype_vswa(self):
         # NOTE: Test with VSWA kv cache config.
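
For context, test_fp8_prequantized drives the prequantized checkpoint through the LLM API roughly as shown below. This is a standalone sketch, not the exact test setup: the checkpoint path is a placeholder (the test resolves it via llm_models_root()), and SamplingParams plus the prompt are illustrative additions not taken from the diff.

# Standalone sketch of what test_fp8_prequantized exercises; the checkpoint
# path below is a placeholder rather than a real location.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

# Block reuse is disabled as a WAR for Gemma3's non-inclusive sliding window,
# matching the test; the FP8 kv-cache dtype matches the prequantized checkpoint.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                enable_partial_reuse=False,
                                dtype="fp8")

with LLM("/path/to/gemma-3-1b-it-fp8",  # placeholder checkpoint location
         kv_cache_config=kv_cache_config) as llm:
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)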

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
   - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
