diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py
index e7f1b2a3109..5c6327ebbd0 100644
--- a/tensorrt_llm/_torch/attention_backend/flashinfer.py
+++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py
@@ -1,3 +1,4 @@
+import math
 import os
 import weakref
 from dataclasses import dataclass, field
@@ -39,6 +40,8 @@ class PlanParams:
     attention_mask_type: AttentionMaskType
     attention_mask_data: Optional[torch.Tensor] = None
+    sm_scale: Optional[float] = None
+    window_left: Optional[int] = None


 @dataclass(kw_only=True)
@@ -309,13 +312,23 @@ def plan(self,
              q_dtype: torch.dtype,
              kv_dtype: torch.dtype,
              attention_mask_type: int,
+             q_scaling: Optional[float] = None,
+             attention_window_size: Optional[int] = None,
              attention_mask_data: Optional[torch.Tensor] = None) -> PlanParams:
+
+        sm_scale = None
+        if q_scaling is not None:
+            sm_scale = 1 / (math.sqrt(head_dim) * q_scaling)
+
         plan_params = PlanParams(
             num_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             q_dtype=q_dtype,
             kv_dtype=kv_dtype,
+            sm_scale=sm_scale,
+            window_left=attention_window_size
+            if attention_window_size is not None else -1,
             attention_mask_type=AttentionMaskType(attention_mask_type),
             attention_mask_data=attention_mask_data)
         return self._plan_with_params(plan_params)
@@ -363,6 +376,8 @@ def prefill_plan():
                 plan_params.head_dim,
                 self.page_size,
                 causal=is_causal,
+                sm_scale=plan_params.sm_scale,
+                window_left=plan_params.window_left,
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
             )
@@ -398,6 +413,8 @@ def decode_plan():
                 plan_params.num_kv_heads,
                 plan_params.head_dim,
                 self.page_size,
+                sm_scale=plan_params.sm_scale,
+                window_left=plan_params.window_left,
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
             )
@@ -431,6 +448,7 @@ def __init__(
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        q_scaling: Optional[float] = None,
         skip_create_weights_in_init: bool = False,
         **kwargs,
     ):
@@ -438,6 +456,7 @@ def __init__(
                          quant_config, **kwargs)
         if not skip_create_weights_in_init:
             self.update_quant_config(self.quant_config)
+        self.q_scaling = q_scaling

     def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
         self.quant_config = new_quant_config
@@ -452,6 +471,7 @@ def forward(self,
                 v: Optional[torch.Tensor],
                 metadata: FlashInferAttentionMetadata,
                 *,
+                attention_window_size: Optional[int] = None,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
                 **kwargs) -> torch.Tensor:
         if attention_mask == PredefinedAttentionMask.CAUSAL:
@@ -463,10 +483,18 @@ def forward(self,
         else:
             raise ValueError("Unexpected attention mask type")

-        return forward_pattern(q, k, v, self.num_heads, self.head_dim,
-                               self.num_kv_heads, self.layer_idx,
-                               self.has_fp8_kv_cache, attention_mask_type,
-                               attention_mask_data)
+        return forward_pattern(q=q,
+                               k=k,
+                               v=v,
+                               num_heads=self.num_heads,
+                               head_dim=self.head_dim,
+                               num_kv_heads=self.num_kv_heads,
+                               layer_idx=self.layer_idx,
+                               has_fp8_kv_cache=self.has_fp8_kv_cache,
+                               attention_mask_type=attention_mask_type,
+                               q_scaling=self.q_scaling,
+                               attention_mask_data=attention_mask_data,
+                               attention_window_size=attention_window_size)


 @torch.library.custom_op("trtllm::flashinfer_forward", mutates_args=())
@@ -480,7 +508,9 @@ def forward_pattern(
     layer_idx: int,
     has_fp8_kv_cache: bool,
     attention_mask_type: int,
-    attention_mask_data: Optional[torch.Tensor],
+    q_scaling: Optional[float] = None,
+    attention_mask_data: Optional[torch.Tensor] = None,
+    attention_window_size: Optional[int] = None,
 ) -> torch.Tensor:
     '''
     Wrapping the flashinfer forward as a custom op is required to fix `torch.compile` graph breaks,
@@ -548,6 +578,8 @@ def decode_forward(plan_params: PlanParams):
                                 head_dim,
                                 q_dtype=q.dtype,
                                 kv_dtype=kv_cache.dtype,
+                                q_scaling=q_scaling,
+                                attention_window_size=attention_window_size,
                                 attention_mask_type=attention_mask_type,
                                 attention_mask_data=attention_mask_data)

diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py
index 6e2a5dfdabf..9d633d3d88a 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -7,7 +7,7 @@
 from transformers import Gemma3TextConfig
 from transformers.activations import ACT2FN

-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.mapping import Mapping

 from ..attention_backend import AttentionMetadata
@@ -64,7 +64,9 @@ def __init__(
         rope_params = RopeParams.from_config(config)
         self.attention_window_size = None
         if is_sliding:
-            rope_params.theta = 10000
+            rope_params.theta = config.rope_local_base_freq
+            rope_params.scale_type = RotaryScalingType.none
+            rope_params.scale = 1.0
             self.attention_window_size = config.sliding_window - 1  # Gemma3 sliding window isn't inclusive.
         pos_embd_params = PositionalEmbeddingParams(
             type=PositionEmbeddingType.rope_gpt_neox,
diff --git a/tests/unittest/_torch/modeling/test_modeling_gemma3.py b/tests/unittest/_torch/modeling/test_modeling_gemma3.py
index 03f4d7c8c1d..b58dd64cb4f 100644
--- a/tests/unittest/_torch/modeling/test_modeling_gemma3.py
+++ b/tests/unittest/_torch/modeling/test_modeling_gemma3.py
@@ -17,9 +17,7 @@
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping

-# This is copied from https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json.
-# Updated to have 1 local layer and 1 global layer. Sliding window size updated to 4.
-GEMMA3_1B_MINI_CONFIG = {
+GEMMA3_1B_CONFIG = {
     "architectures": ["Gemma3ForCausalLM"],
     "attention_bias": False,
     "attention_dropout": 0.0,
@@ -36,7 +34,7 @@
     "max_position_embeddings": 32768,
     "model_type": "gemma3_text",
     "num_attention_heads": 4,
-    "num_hidden_layers": 2,  # Modified for testing.
+    "num_hidden_layers": 26,
     "num_key_value_heads": 1,
     "pad_token_id": 0,
     "query_pre_attn_scalar": 256,
@@ -44,21 +42,60 @@
     "rope_local_base_freq": 10000,
     "rope_scaling": None,
     "rope_theta": 1000000,
-    "sliding_window": 4,  # Modified for testing.
-    "sliding_window_pattern": 2,  # Modified for testing.
+    "sliding_window": 512,
+    "sliding_window_pattern": 6,
     "torch_dtype": "bfloat16",
     "transformers_version": "4.50.0.dev0",
    "use_cache": True,
     "vocab_size": 262144
 }

+GEMMA3_27B_CONFIG = {
+    "architectures": ["Gemma3ForConditionalGeneration"],
+    "boi_token_index": 255999,
+    "eoi_token_index": 256000,
+    "eos_token_id": [1, 106],
+    "image_token_index": 262144,
+    "initializer_range": 0.02,
+    "mm_tokens_per_image": 256,
+    "model_type": "gemma3",
+    "text_config": {
+        "head_dim": 128,
+        "hidden_size": 5376,
+        "intermediate_size": 21504,
+        "model_type": "gemma3_text",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 62,
+        "num_key_value_heads": 16,
+        "query_pre_attn_scalar": 168,
+        "rope_scaling": {
+            "factor": 8.0,
+            "rope_type": "linear"
+        },
+        "sliding_window": 1024
+    },
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.50.0.dev0",
+    "vision_config": {
+        "hidden_size": 1152,
+        "image_size": 896,
+        "intermediate_size": 4304,
+        "model_type": "siglip_vision_model",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 27,
+        "patch_size": 14,
+        "vision_use_head": False
+    }
+}
+

 @dataclass(repr=False)
 class Scenario:
     backend: str
+    config_name: str

     def __repr__(self) -> str:
-        return f"backend:{self.backend.lower()}"
+        return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}"


 class TestGemma3(unittest.TestCase):
@@ -95,7 +132,8 @@ def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config,


     def test_gemma3_sanity(self):
-        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
+        config_dict = deepcopy(
+            GEMMA3_1B_CONFIG)  # Using 1B config for sanity test.
         gemma3_config = Gemma3Config.from_dict(config_dict)
         dtype = gemma3_config.torch_dtype

@@ -174,8 +212,12 @@ def test_gemma3_sanity(self):
         kv_cache_manager.shutdown()

     @parameterized.expand([
-        Scenario(backend="TRTLLM"),
-        Scenario(backend="VANILLA"),
+        Scenario(backend="TRTLLM", config_name="1B"),
+        Scenario(backend="VANILLA", config_name="1B"),
+        Scenario(backend="FLASHINFER", config_name="1B"),
+        Scenario(backend="TRTLLM", config_name="27B"),
+        Scenario(backend="VANILLA", config_name="27B"),
+        Scenario(backend="FLASHINFER", config_name="27B"),
     ], lambda testcase_func, param_num, param:
                           f"{testcase_func.__name__}[{param.args[0]}]")
     @torch.no_grad()
@@ -184,14 +226,31 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
         Compare output to HF.
         """
         backend = scenario.backend
+        config_name = scenario.config_name
         metadata_cls = get_attention_backend(backend).Metadata

         torch.random.manual_seed(0)
-        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
+
+        # Select the appropriate config based on the scenario
+        if config_name == "1B":
+            config_dict = deepcopy(GEMMA3_1B_CONFIG)
+        elif config_name == "27B":
+            config_dict = deepcopy(GEMMA3_27B_CONFIG)
+        else:
+            raise ValueError(f"Unknown config_name: {config_name}")
+
         gemma3_config = Gemma3Config.from_dict(config_dict)
+        if config_name == "27B":
+            gemma3_config.text_config.torch_dtype = gemma3_config.torch_dtype
+            gemma3_config = gemma3_config.text_config
         dtype = gemma3_config.torch_dtype
         device = torch.device('cuda')

+        # 2-layer network with one local (sliding window=4) and one global layer.
+        gemma3_config.num_hidden_layers = 2
+        gemma3_config.sliding_window = 4
+        gemma3_config.sliding_window_pattern = 2
+
         num_blocks = 1
         tokens_per_block = 128
         max_seq_len = num_blocks * tokens_per_block
@@ -253,6 +312,7 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
                              position_ids=position_ids,
                              past_key_values=hf_cache,
                              use_cache=True)
+
         torch.testing.assert_close(logits,
                                    ref.logits[:, -1].float(),
                                    atol=0.05,
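
Reviewer note (not part of the patch): below is a minimal standalone sketch of the parameter mapping that the new plan() arguments introduce. The helper name to_flashinfer_plan_args and the example values are illustrative assumptions; only the two formulas, sm_scale = 1 / (sqrt(head_dim) * q_scaling) and window_left = attention_window_size (with -1 when no window is set), come from the diff above.

import math
from typing import Optional, Tuple


def to_flashinfer_plan_args(
        head_dim: int,
        q_scaling: Optional[float] = None,
        attention_window_size: Optional[int] = None
) -> Tuple[Optional[float], int]:
    """Illustrative helper (assumed, not in the patch): mirrors how plan()
    derives FlashInfer's sm_scale and window_left from the new arguments."""
    # With q_scaling unset, sm_scale stays None and FlashInfer falls back to
    # its default softmax scale of 1/sqrt(head_dim).
    sm_scale = None
    if q_scaling is not None:
        sm_scale = 1 / (math.sqrt(head_dim) * q_scaling)

    # window_left == -1 means no sliding window, matching the patch's default.
    window_left = attention_window_size if attention_window_size is not None else -1
    return sm_scale, window_left


if __name__ == "__main__":
    # Global Gemma3 layer: no window, default scale -> (None, -1).
    print(to_flashinfer_plan_args(head_dim=256))
    # Local layer: modeling_gemma3.py passes config.sliding_window - 1
    # (e.g. 512 - 1 = 511) because Gemma3's sliding window isn't inclusive.
    print(to_flashinfer_plan_args(head_dim=256,
                                  q_scaling=1.0,
                                  attention_window_size=511))

Since sm_scale and window_left are stored on PlanParams, layers with different window sizes or scaling presumably end up with separate cached FlashInfer plans rather than sharing one.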