diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py
index db672279646..727e45018e1 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -158,25 +158,29 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
 
 class Gemma3MLP(nn.Module):
 
-    def __init__(self, config: Gemma3TextConfig):
+    def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
         super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
-        self.dtype = config.torch_dtype
+        self.config = model_config.pretrained_config
+        self.hidden_size = self.config.hidden_size
+        self.intermediate_size = self.config.intermediate_size
+        self.dtype = self.config.torch_dtype
+        self.quant_config = model_config.get_quant_config()
         self.gate_proj = Linear(self.hidden_size,
                                 self.intermediate_size,
                                 bias=False,
-                                dtype=self.dtype)
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
         self.up_proj = Linear(self.hidden_size,
                               self.intermediate_size,
                               bias=False,
-                              dtype=self.dtype)
+                              dtype=self.dtype,
+                              quant_config=self.quant_config)
         self.down_proj = Linear(self.intermediate_size,
                                 self.hidden_size,
                                 bias=False,
-                                dtype=self.dtype)
-        self.act_fn = ACT2FN[config.hidden_activation]
+                                dtype=self.dtype,
+                                quant_config=self.quant_config)
+        self.act_fn = ACT2FN[self.config.hidden_activation]
 
     @torch.inference_mode()
     def forward(self, x):
@@ -202,7 +206,7 @@ def __init__(
             is_sliding=is_sliding,
         )
 
-        self.mlp = Gemma3MLP(config)
+        self.mlp = Gemma3MLP(model_config=model_config)
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
diff --git a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml
index 95bbc760477..fd4c43093fc 100644
--- a/tests/integration/defs/accuracy/references/cnn_dailymail.yaml
+++ b/tests/integration/defs/accuracy/references/cnn_dailymail.yaml
@@ -1,5 +1,8 @@
 google/gemma-3-1b-it:
   - accuracy: 22.988
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 22.988
 google/gemma-3-27b-it:
   - accuracy: 28.90
 gpt2:
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index 2523418afcb..3d387f36b80 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -100,6 +100,11 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
   - accuracy: 81.7
 google/gemma-2-9b-it:
   - accuracy: 73.05
+google/gemma-3-1b-it:
+  - accuracy: 39.0
+  - quant_algo: FP8
+    kv_cache_quant_algo: FP8
+    accuracy: 39.0
 google/gemma-3-27b-it:
   - accuracy: 77.80
 Qwen/Qwen2-0.5B-Instruct:
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index e52cde99ad5..311cf96f4f6 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -604,6 +604,22 @@ def test_auto_dtype(self):
             task.evaluate(llm)
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
+        task = MMLU(self.MODEL_NAME)
+        task.evaluate(llm)
+
+    def test_fp8_prequantized(self):
+        # Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
+        kv_cache_config = KvCacheConfig(enable_block_reuse=False,
+                                        enable_partial_reuse=False,
+                                        dtype="fp8")
+        prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it-fp8/"
+        with LLM(prequantized_model_path,
+                 kv_cache_config=kv_cache_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            task = CnnDailymail(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
 
     def test_auto_dtype_vswa(self):
         # NOTE: Test with VSWA kv cache config.
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 957c6697c3a..ffbc06b79d1 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -191,6 +191,7 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
   - accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
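
Usage sketch (not part of the patch): a minimal example of how the FP8 prequantized path added here might be driven through the LLM API, mirroring the new test_fp8_prequantized test. The checkpoint path and prompt are hypothetical placeholders (the test resolves the real path via llm_models_root()); the KvCacheConfig arguments, the FP8 assertion, and the block-reuse WAR are taken from the diff, while the exact import locations are assumptions based on how the test file uses these names.

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import KvCacheConfig
    from tensorrt_llm.quantization import QuantAlgo

    # Hypothetical local path to an FP8-prequantized Gemma3 checkpoint.
    model_path = "/models/gemma/gemma-3-1b-it-fp8"

    # Same WAR as in the test: disable kv cache block reuse because of gaps in
    # kernel support for Gemma3's non-inclusive sliding window size.
    kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                    enable_partial_reuse=False,
                                    dtype="fp8")

    with LLM(model_path, kv_cache_config=kv_cache_config) as llm:
        # The prequantized checkpoint should be detected as FP8 from its
        # config, as the test asserts.
        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
        outputs = llm.generate(["The capital of France is"])
        print(outputs[0].outputs[0].text)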