42 changes: 37 additions & 5 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
@@ -1,3 +1,4 @@
import math
import os
import weakref
from dataclasses import dataclass, field
@@ -39,6 +40,8 @@ class PlanParams:

attention_mask_type: AttentionMaskType
attention_mask_data: Optional[torch.Tensor] = None
sm_scale: Optional[float] = None
window_left: Optional[int] = None


@dataclass(kw_only=True)
@@ -309,13 +312,23 @@ def plan(self,
q_dtype: torch.dtype,
kv_dtype: torch.dtype,
attention_mask_type: int,
q_scaling: Optional[float] = None,
attention_window_size: Optional[int] = None,
attention_mask_data: Optional[torch.Tensor] = None) -> PlanParams:

sm_scale = None
if q_scaling is not None:
sm_scale = 1 / (math.sqrt(head_dim) * q_scaling)

plan_params = PlanParams(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
q_dtype=q_dtype,
kv_dtype=kv_dtype,
sm_scale=sm_scale,
window_left=attention_window_size
if attention_window_size is not None else -1,
attention_mask_type=AttentionMaskType(attention_mask_type),
attention_mask_data=attention_mask_data)
return self._plan_with_params(plan_params)
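A minimal sketch (not code from this PR) of how the new `q_scaling` argument maps onto FlashInfer's `sm_scale` in the conversion above: `q_scaling=1.0` recovers the conventional `1/sqrt(head_dim)` softmax scale, while a custom value rescales it, which is what allows Gemma3-style models to scale attention by `1/sqrt(query_pre_attn_scalar)` instead. The 27B numbers below (head_dim=128, query_pre_attn_scalar=168) come from the test config later in this diff; choosing `q_scaling = sqrt(query_pre_attn_scalar / head_dim)` is an assumption for illustration.

```python
# Sketch only: replicates the q_scaling -> sm_scale conversion added in plan().
import math
from typing import Optional


def sm_scale_from_q_scaling(head_dim: int,
                            q_scaling: Optional[float]) -> Optional[float]:
    """Softmax scale handed to FlashInfer: 1 / (sqrt(head_dim) * q_scaling)."""
    if q_scaling is None:
        return None  # FlashInfer then falls back to its default 1/sqrt(head_dim).
    return 1.0 / (math.sqrt(head_dim) * q_scaling)


# q_scaling == 1.0 recovers the conventional 1/sqrt(head_dim) scale.
assert math.isclose(sm_scale_from_q_scaling(128, 1.0), 1.0 / math.sqrt(128))

# Assumed usage: to scale attention by 1/sqrt(query_pre_attn_scalar) rather than
# 1/sqrt(head_dim), a model would pass q_scaling = sqrt(query_pre_attn_scalar / head_dim).
q_scaling = math.sqrt(168 / 128)
assert math.isclose(sm_scale_from_q_scaling(128, q_scaling), 1.0 / math.sqrt(168))
```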
@@ -363,6 +376,8 @@ def prefill_plan():
plan_params.head_dim,
self.page_size,
causal=is_causal,
sm_scale=plan_params.sm_scale,
window_left=plan_params.window_left,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
)
@@ -398,6 +413,8 @@ def decode_plan():
plan_params.num_kv_heads,
plan_params.head_dim,
self.page_size,
sm_scale=plan_params.sm_scale,
window_left=plan_params.window_left,
q_data_type=plan_params.q_dtype,
kv_data_type=plan_params.kv_dtype,
)
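For reference, a toy illustration of what `window_left` means in the plan calls above; the exact semantics are an assumption, not FlashInfer kernel code: `-1` disables the sliding window, and a non-negative value limits each query to that many tokens strictly to its left, with the query position itself always visible.

```python
# Toy mask check, assuming window_left counts tokens strictly to the left of the
# query and -1 means "no sliding window".
def allowed(q_pos: int, kv_pos: int, window_left: int) -> bool:
    if kv_pos > q_pos:       # causal: never attend to future positions
        return False
    if window_left < 0:      # -1 -> unlimited history
        return True
    return q_pos - kv_pos <= window_left


# window_left=3 (e.g. a Gemma3 layer with sliding_window=4, exclusive of the
# query token): position 10 can attend to positions 7..10 only.
assert [p for p in range(11) if allowed(10, p, 3)] == [7, 8, 9, 10]
assert [p for p in range(11) if allowed(10, p, -1)] == list(range(11))
```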
@@ -431,13 +448,15 @@ def __init__(
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
q_scaling: Optional[float] = None,
skip_create_weights_in_init: bool = False,
**kwargs,
):
super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
quant_config, **kwargs)
if not skip_create_weights_in_init:
self.update_quant_config(self.quant_config)
self.q_scaling = q_scaling

def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
self.quant_config = new_quant_config
@@ -452,6 +471,7 @@ def forward(self,
v: Optional[torch.Tensor],
metadata: FlashInferAttentionMetadata,
*,
attention_window_size: Optional[int] = None,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
**kwargs) -> torch.Tensor:
if attention_mask == PredefinedAttentionMask.CAUSAL:
@@ -463,10 +483,18 @@
else:
raise ValueError("Unexpected attention mask type")

return forward_pattern(q, k, v, self.num_heads, self.head_dim,
self.num_kv_heads, self.layer_idx,
self.has_fp8_kv_cache, attention_mask_type,
attention_mask_data)
return forward_pattern(q=q,
k=k,
v=v,
num_heads=self.num_heads,
head_dim=self.head_dim,
num_kv_heads=self.num_kv_heads,
layer_idx=self.layer_idx,
has_fp8_kv_cache=self.has_fp8_kv_cache,
attention_mask_type=attention_mask_type,
q_scaling=self.q_scaling,
attention_mask_data=attention_mask_data,
attention_window_size=attention_window_size)


@torch.library.custom_op("trtllm::flashinfer_forward", mutates_args=())
@@ -480,7 +508,9 @@ def forward_pattern(
layer_idx: int,
has_fp8_kv_cache: bool,
attention_mask_type: int,
attention_mask_data: Optional[torch.Tensor],
q_scaling: Optional[float] = None,
attention_mask_data: Optional[torch.Tensor] = None,
attention_window_size: Optional[int] = None,
) -> torch.Tensor:
'''
Wrapping the flashinfer forward as a custom op is required to fix `torch.compile` graph breaks,
@@ -548,6 +578,8 @@ def decode_forward(plan_params: PlanParams):
head_dim,
q_dtype=q.dtype,
kv_dtype=kv_cache.dtype,
q_scaling=q_scaling,
attention_window_size=attention_window_size,
attention_mask_type=attention_mask_type,
attention_mask_data=attention_mask_data)

6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -7,7 +7,7 @@
from transformers import Gemma3TextConfig
from transformers.activations import ACT2FN

from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
from tensorrt_llm.mapping import Mapping

from ..attention_backend import AttentionMetadata
@@ -64,7 +64,9 @@ def __init__(
rope_params = RopeParams.from_config(config)
self.attention_window_size = None
if is_sliding:
rope_params.theta = 10000
rope_params.theta = config.rope_local_base_freq
rope_params.scale_type = RotaryScalingType.none
rope_params.scale = 1.0
self.attention_window_size = config.sliding_window - 1 # Gemma3 sliding window isn't inclusive.
pos_embd_params = PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
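To make the intent of the change above concrete, here is a small self-contained sketch (not code from the PR) of how per-layer RoPE and sliding-window settings could be derived from a Gemma3 text config. The `is_sliding` rule (every `sliding_window_pattern`-th layer is global) is an assumption, consistent with the 2-layer test override further down, and the values mirror `GEMMA3_1B_CONFIG` in the test file below.

```python
# Sketch only: local (sliding) layers use rope_local_base_freq with no RoPE
# scaling and an exclusive sliding window; global layers keep rope_theta and
# the config's rope_scaling. The layer-pattern rule is an assumption.
def layer_attention_params(config: dict, layer_idx: int) -> dict:
    is_sliding = (layer_idx + 1) % config["sliding_window_pattern"] != 0
    if is_sliding:
        return {
            "rope_theta": config["rope_local_base_freq"],
            "rope_scaling": None,
            # Gemma3's sliding window isn't inclusive, hence the -1.
            "attention_window_size": config["sliding_window"] - 1,
        }
    return {
        "rope_theta": config["rope_theta"],
        "rope_scaling": config["rope_scaling"],
        "attention_window_size": None,  # full causal attention
    }


cfg = {"rope_theta": 1000000, "rope_local_base_freq": 10000,
       "rope_scaling": None, "sliding_window": 512, "sliding_window_pattern": 6}
assert layer_attention_params(cfg, 0)["attention_window_size"] == 511  # local layer
assert layer_attention_params(cfg, 5)["rope_theta"] == 1000000         # every 6th layer is global
```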
82 changes: 71 additions & 11 deletions tests/unittest/_torch/modeling/test_modeling_gemma3.py
@@ -17,9 +17,7 @@
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping

# This is copied from https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json.
# Updated to have 1 local layer and 1 global layer. Sliding window size updated to 4.
GEMMA3_1B_MINI_CONFIG = {
GEMMA3_1B_CONFIG = {
"architectures": ["Gemma3ForCausalLM"],
"attention_bias": False,
"attention_dropout": 0.0,
@@ -36,29 +34,68 @@
"max_position_embeddings": 32768,
"model_type": "gemma3_text",
"num_attention_heads": 4,
"num_hidden_layers": 2, # Modified for testing.
"num_hidden_layers": 26,
"num_key_value_heads": 1,
"pad_token_id": 0,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": 4, # Modified for testing.
"sliding_window_pattern": 2, # Modified for testing.
"sliding_window": 512,
"sliding_window_pattern": 6,
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0.dev0",
"use_cache": True,
"vocab_size": 262144
}

GEMMA3_27B_CONFIG = {
"architectures": ["Gemma3ForConditionalGeneration"],
"boi_token_index": 255999,
"eoi_token_index": 256000,
"eos_token_id": [1, 106],
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"text_config": {
"head_dim": 128,
"hidden_size": 5376,
"intermediate_size": 21504,
"model_type": "gemma3_text",
"num_attention_heads": 32,
"num_hidden_layers": 62,
"num_key_value_heads": 16,
"query_pre_attn_scalar": 168,
"rope_scaling": {
"factor": 8.0,
"rope_type": "linear"
},
"sliding_window": 1024
},
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0.dev0",
"vision_config": {
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"vision_use_head": False
}
}


@dataclass(repr=False)
class Scenario:
backend: str
config_name: str

def __repr__(self) -> str:
return f"backend:{self.backend.lower()}"
return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}"


class TestGemma3(unittest.TestCase):
@@ -95,7 +132,8 @@ def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config,

def test_gemma3_sanity(self):

config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
config_dict = deepcopy(
GEMMA3_1B_CONFIG) # Using 1B config for sanity test.
gemma3_config = Gemma3Config.from_dict(config_dict)

dtype = gemma3_config.torch_dtype
@@ -174,8 +212,12 @@ def test_gemma3_sanity(self):
kv_cache_manager.shutdown()

@parameterized.expand([
Scenario(backend="TRTLLM"),
Scenario(backend="VANILLA"),
Scenario(backend="TRTLLM", config_name="1B"),
Scenario(backend="VANILLA", config_name="1B"),
Scenario(backend="FLASHINFER", config_name="1B"),
Scenario(backend="TRTLLM", config_name="27B"),
Scenario(backend="VANILLA", config_name="27B"),
Scenario(backend="FLASHINFER", config_name="27B"),
], lambda testcase_func, param_num, param:
f"{testcase_func.__name__}[{param.args[0]}]")
@torch.no_grad()
@@ -184,14 +226,31 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
Compare output to HF.
"""
backend = scenario.backend
config_name = scenario.config_name
metadata_cls = get_attention_backend(backend).Metadata

torch.random.manual_seed(0)
config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)

# Select the appropriate config based on the scenario
if config_name == "1B":
config_dict = deepcopy(GEMMA3_1B_CONFIG)
elif config_name == "27B":
config_dict = deepcopy(GEMMA3_27B_CONFIG)
else:
raise ValueError(f"Unknown config_name: {config_name}")

gemma3_config = Gemma3Config.from_dict(config_dict)
if config_name == "27B":
gemma3_config.text_config.torch_dtype = gemma3_config.torch_dtype
gemma3_config = gemma3_config.text_config
dtype = gemma3_config.torch_dtype
device = torch.device('cuda')

# 2-layer network with one local (sliding window=4) and one global layer.
gemma3_config.num_hidden_layers = 2
gemma3_config.sliding_window = 4
gemma3_config.sliding_window_pattern = 2

num_blocks = 1
tokens_per_block = 128
max_seq_len = num_blocks * tokens_per_block
@@ -253,6 +312,7 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
position_ids=position_ids,
past_key_values=hf_cache,
use_cache=True)

torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.05,