Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions tensorrt_llm/_torch/models/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,25 +158,29 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],

class Gemma3MLP(nn.Module):

def __init__(self, config: Gemma3TextConfig):
def __init__(self, model_config: ModelConfig[Gemma3TextConfig]):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.dtype = config.torch_dtype
self.config = model_config.pretrained_config
self.hidden_size = self.config.hidden_size
self.intermediate_size = self.config.intermediate_size
self.dtype = self.config.torch_dtype
self.quant_config = model_config.get_quant_config()
self.gate_proj = Linear(self.hidden_size,
self.intermediate_size,
bias=False,
dtype=self.dtype)
dtype=self.dtype,
quant_config=self.quant_config)
self.up_proj = Linear(self.hidden_size,
self.intermediate_size,
bias=False,
dtype=self.dtype)
dtype=self.dtype,
quant_config=self.quant_config)
self.down_proj = Linear(self.intermediate_size,
self.hidden_size,
bias=False,
dtype=self.dtype)
self.act_fn = ACT2FN[config.hidden_activation]
dtype=self.dtype,
quant_config=self.quant_config)
self.act_fn = ACT2FN[self.config.hidden_activation]

@torch.inference_mode()
def forward(self, x):
Expand All @@ -202,7 +206,7 @@ def __init__(
is_sliding=is_sliding,
)

self.mlp = Gemma3MLP(config)
self.mlp = Gemma3MLP(model_config=model_config)

self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/defs/accuracy/references/cnn_dailymail.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
google/gemma-3-1b-it:
- accuracy: 22.988
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 22.988
google/gemma-3-27b-it:
- accuracy: 28.90
gpt2:
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/defs/accuracy/references/mmlu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ mistralai/Mistral-Small-3.1-24B-Instruct-2503:
- accuracy: 81.7
google/gemma-2-9b-it:
- accuracy: 73.05
google/gemma-3-1b-it:
- accuracy: 39.0
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 39.0
google/gemma-3-27b-it:
- accuracy: 77.80
Qwen/Qwen2-0.5B-Instruct:
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,22 @@ def test_auto_dtype(self):
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

def test_fp8_prequantized(self):
# Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
enable_partial_reuse=False,
dtype="fp8")
prequantized_model_path = f"{llm_models_root()}/gemma/gemma-3-1b-it-fp8/"
with LLM(prequantized_model_path,
kv_cache_config=kv_cache_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

def test_auto_dtype_vswa(self):
# NOTE: Test with VSWA kv cache config.
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_fp8_prequantized
- accuracy/test_llm_api_pytorch.py::TestGemma3_27BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestMistralSmall24B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
Expand Down