From c4d18f080241cdf25b26b2b73bd158d4b76a0fc6 Mon Sep 17 00:00:00 2001
From: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
Date: Fri, 12 Dec 2025 12:44:01 -0500
Subject: [PATCH 1/5] [Gemma3] Fix RoPE scaling for local attention

Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
---
 tensorrt_llm/layers/attention.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index f995b6390d3..acdf703334e 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -702,16 +702,20 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base, names_to_register):
+            def register_rope_params(rotary_base, names_to_register, is_local=False):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
+                # For local attention, use no scaling (consistent with forward pass)
+                local_scale = 1.0 if is_local else rotary_embedding_scale
+                local_scale_type = RotaryScalingType.none if is_local else rotary_embedding_scale_type
+                local_scaling = None if is_local else rotary_embedding_scaling
+
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    rotary_embedding_scale, rotary_embedding_scale_type,
-                    rotary_embedding_scaling)
+                    local_scale, local_scale_type, local_scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
@@ -739,7 +743,8 @@ def register_rope_params(rotary_base, names_to_register):
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
-                    ])
+                    ],
+                    is_local=True)
 
     @staticmethod
     def fill_attention_params(model_cls, attention_params):
@@ -1141,10 +1146,10 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale,
+                rotary_embedding_scale=self.rotary_embedding_scale if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
@@ -2792,4 +2797,4 @@ def forward(self,
                               attention_mask=attention_mask,
                               max_input_length=max_input_length,
                               *args,
-                              **kwargs)
+                              **kwargs)
\ No newline at end of file

From c8b18cf8c144ff0e2e8f341c56af9858b2db6ceb Mon Sep 17 00:00:00 2001
From: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
Date: Fri, 12 Dec 2025 13:17:11 -0500
Subject: [PATCH 2/5] Add unit test for Gemma3 local attention RoPE scaling

Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
---
 tests/unittest/others/test_layer.py | 108 ++++++++++++++++++++++++++++
 1 file changed, 108 insertions(+)

diff --git a/tests/unittest/others/test_layer.py b/tests/unittest/others/test_layer.py
index 8de1189dc90..0c399167451 100644
--- a/tests/unittest/others/test_layer.py
+++ b/tests/unittest/others/test_layer.py
@@ -2116,5 +2116,113 @@ def fuse_rg_lru(recurrent_layer):
             rtol=rtol)
+
+    def test_gemma3_local_attention_rope_scaling(self):
+        """
+        Test that local attention layers in Gemma3 do NOT apply rope scaling,
+        even when the config has rope_scaling defined.
+
+        This is important for Gemma3 which uses different RoPE parameters for
+        local (sliding window) attention vs global attention layers. The fix
+        ensures that local attention layers get scale=1.0 and scale_type=none,
+        while global layers get the configured scaling.
+        """
+        from tensorrt_llm.functional import (PositionEmbeddingType,
+                                             RotaryScalingType)
+        from tensorrt_llm.layers.attention import Attention
+
+        # Create a mock config similar to Gemma3 27B with rope_scaling
+        class MockGemma3Config:
+            hidden_size = 5376
+            num_attention_heads = 32
+            head_size = 128
+            max_position_embeddings = 32768
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+            rotary_base = 1000000.0
+            rotary_scaling = {
+                "factor": 8.0,
+                "rope_type": "linear"
+            }
+            rotary_pct = 1.0
+            # Local attention uses a different base frequency
+            rope_local_base_freq = 10000.0
+
+        # Create a mock model class to receive registered parameters
+        class MockModelCls:
+            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
+
+            @classmethod
+            def register_parameter(cls, name, param):
+                setattr(cls, name, param)
+
+        config = MockGemma3Config()
+
+        # Call the method that creates attention const params
+        Attention.create_attention_const_params(MockModelCls, config)
+
+        # Verify that global rope parameters are registered
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
+                        "Global embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
+                        "Global rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
+            "Global embed_positions_for_gpt_attention should be registered")
+
+        # Verify that local rope parameters are registered (since rope_local_base_freq is set)
+        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
+                        "Local embed_positions should be registered")
+        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
+                        "Local rotary_inv_freq should be registered")
+        self.assertTrue(
+            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
+            "Local embed_positions_for_gpt_attention should be registered")
+
+        # Get the parameter values
+        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
+        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
+        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
+        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value
+
+        # The global and local inv_freq should be different because:
+        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
+        # 2. Local uses scale=1.0 (no scaling)
+        # Also they use different base frequencies (1000000 vs 10000)
+        self.assertFalse(
+            np.allclose(global_inv_freq, local_inv_freq),
+            "Global and local rotary_inv_freq should be different "
+            "(global has scaling, local does not)")
+
+        # The cos/sin embeddings should also be different
+        self.assertFalse(
+            np.allclose(global_cos_sin, local_cos_sin),
+            "Global and local embed_positions_for_gpt_attention should be different "
+            "(global has scaling, local does not)")
+
+        # Additional verification: Check that local inv_freq matches unscaled calculation
+        # For local attention with scale=1.0 and base=10000:
+        # inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
+        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq**(
+            np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            local_inv_freq,
+            expected_local_inv_freq,
+            rtol=1e-5,
+            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")
+
+        # For global attention with linear scaling (factor=8.0):
+        # scale = 1.0 / 8.0 = 0.125
+        # inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**(
+            np.arange(0, dim, 2) / dim))
+
+        np.testing.assert_allclose(
+            global_inv_freq,
+            expected_global_inv_freq,
+            rtol=1e-5,
+            err_msg="Global rotary_inv_freq should be computed WITH linear scaling")
+
+
 if __name__ == '__main__':
     unittest.main()

From dcfb80acf8b2b5740b33843333db687cba22a8 Mon Sep 17 00:00:00 2001
From: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
Date: Wed, 17 Dec 2025 19:13:36 +0000
Subject: [PATCH 3/5] lint/fmt using pre-commit

Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
---
 tensorrt_llm/layers/attention.py    | 12 ++++++++----
 tests/unittest/others/test_layer.py | 20 ++++++++------------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index acdf703334e..eb3fa703f64 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -702,7 +702,9 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base, names_to_register, is_local=False):
+            def register_rope_params(rotary_base,
+                                     names_to_register,
+                                     is_local=False):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
@@ -1146,10 +1148,12 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_dim=self.rotary_embedding_dim,
                 rotary_embedding_base=self.rotary_embedding_base
                 if not self.is_local else self.rotary_embedding_base_local,
-                rotary_embedding_scale_type=self.rotary_embedding_scale_type if not self.is_local else RotaryScalingType.none,
+                rotary_embedding_scale_type=self.rotary_embedding_scale_type
+                if not self.is_local else RotaryScalingType.none,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
-                rotary_embedding_scale=self.rotary_embedding_scale if not self.is_local else 1.0,
+                rotary_embedding_scale=self.rotary_embedding_scale
+                if not self.is_local else 1.0,
                 rotary_embedding_max_positions=self.max_position_embeddings,
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
@@ -2797,4 +2801,4 @@ def forward(self,
                               attention_mask=attention_mask,
                               max_input_length=max_input_length,
                               *args,
-                              **kwargs)
\ No newline at end of file
+                              **kwargs)
diff --git a/tests/unittest/others/test_layer.py b/tests/unittest/others/test_layer.py
index 0c399167451..7afcb28eded 100644
--- a/tests/unittest/others/test_layer.py
+++ b/tests/unittest/others/test_layer.py
@@ -2115,7 +2115,6 @@ def fuse_rg_lru(recurrent_layer):
             atol=atol,
             rtol=rtol)
 
-
     def test_gemma3_local_attention_rope_scaling(self):
         """
         Test that local attention layers in Gemma3 do NOT apply rope scaling,
@@ -2126,8 +2125,7 @@ def test_gemma3_local_attention_rope_scaling(self):
         ensures that local attention layers get scale=1.0 and scale_type=none,
         while global layers get the configured scaling.
         """
-        from tensorrt_llm.functional import (PositionEmbeddingType,
-                                             RotaryScalingType)
+        from tensorrt_llm.functional import PositionEmbeddingType
         from tensorrt_llm.layers.attention import Attention
 
         # Create a mock config similar to Gemma3 27B with rope_scaling
@@ -2138,10 +2136,7 @@ class MockGemma3Config:
             max_position_embeddings = 32768
             position_embedding_type = PositionEmbeddingType.rope_gpt_neox
             rotary_base = 1000000.0
-            rotary_scaling = {
-                "factor": 8.0,
-                "rope_type": "linear"
-            }
+            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
             rotary_pct = 1.0
             # Local attention uses a different base frequency
             rope_local_base_freq = 10000.0
@@ -2202,8 +2197,8 @@ def register_parameter(cls, name, param):
         # For local attention with scale=1.0 and base=10000:
         # inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
         dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
-        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq**(
-            np.arange(0, dim, 2) / dim))
+        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
+                                         **(np.arange(0, dim, 2) / dim))
 
         np.testing.assert_allclose(
             local_inv_freq,
             expected_local_inv_freq,
             rtol=1e-5,
@@ -2214,14 +2209,15 @@ def register_parameter(cls, name, param):
         # For global attention with linear scaling (factor=8.0):
         # scale = 1.0 / 8.0 = 0.125
         # inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
-        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**(
-            np.arange(0, dim, 2) / dim))
+        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
+                                                  (np.arange(0, dim, 2) / dim))
 
         np.testing.assert_allclose(
             global_inv_freq,
             expected_global_inv_freq,
             rtol=1e-5,
-            err_msg="Global rotary_inv_freq should be computed WITH linear scaling")
+            err_msg=
+            "Global rotary_inv_freq should be computed WITH linear scaling")
 
 
 if __name__ == '__main__':

From 81ff881b4df05fdab07a55021786310fce7073a0 Mon Sep 17 00:00:00 2001
From: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
Date: Fri, 19 Dec 2025 13:42:17 +0000
Subject: [PATCH 4/5] Pass RoPE scale parameters explicitly; use smaller rotary bases in test

Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
---
 tensorrt_llm/layers/attention.py    | 20 ++++++++++----------
 tests/unittest/others/test_layer.py |  7 +++++--
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index eb3fa703f64..e9f298836ff 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -702,22 +702,17 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base,
-                                     names_to_register,
-                                     is_local=False):
+            def register_rope_params(rotary_base, scale, scale_type, scaling,
+                                     names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
                     max_position_embeddings,
                     rotary_embedding_dim,
                 )
-                # For local attention, use no scaling (consistent with forward pass)
-                local_scale = 1.0 if is_local else rotary_embedding_scale
-                local_scale_type = RotaryScalingType.none if is_local else rotary_embedding_scale_type
-                local_scaling = None if is_local else rotary_embedding_scaling
 
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    local_scale, local_scale_type, local_scaling)
+                    scale, scale_type, scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
@@ -731,6 +726,9 @@ def register_rope_params(rotary_base,
                           is_buffer=True))
 
             register_rope_params(rotary_base=rotary_embedding_base,
+                                 scale=rotary_embedding_scale,
+                                 scale_type=rotary_embedding_scale_type,
+                                 scaling=rotary_embedding_scaling,
                                  names_to_register=[
                                      'embed_positions', 'rotary_inv_freq',
                                      'embed_positions_for_gpt_attention'
                                  ])
@@ -742,11 +740,13 @@ def register_rope_params(rotary_base,
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
+                    scale=1.0,
+                    scale_type=RotaryScalingType.none,
+                    scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
-                    ],
-                    is_local=True)
+                    ])
 
     @staticmethod
     def fill_attention_params(model_cls, attention_params):
diff --git a/tests/unittest/others/test_layer.py b/tests/unittest/others/test_layer.py
index 7afcb28eded..5f80622af1d 100644
--- a/tests/unittest/others/test_layer.py
+++ b/tests/unittest/others/test_layer.py
@@ -2135,11 +2135,14 @@ class MockGemma3Config:
             head_size = 128
             max_position_embeddings = 32768
             position_embedding_type = PositionEmbeddingType.rope_gpt_neox
-            rotary_base = 1000000.0
+            # Use small rotary base values to avoid numerical instability in tests.
+            # Large bases (e.g. 1000000) get exponentiated, causing potential flakiness
+            # when comparing floating point results.
+            rotary_base = 100.0
             rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
             rotary_pct = 1.0
             # Local attention uses a different base frequency
-            rope_local_base_freq = 10000.0
+            rope_local_base_freq = 10.0
 
         # Create a mock model class to receive registered parameters
         class MockModelCls:

From 025b7c7234290f03e0afc69debe1ca6c59ce6827 Mon Sep 17 00:00:00 2001
From: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
Date: Mon, 22 Dec 2025 01:49:53 +0000
Subject: [PATCH 5/5] address comments

Signed-off-by: Shiv Ghai <8965168+shivghai@users.noreply.github.com>
---
 tensorrt_llm/layers/attention.py    | 30 ++++++++++++++++-------------
 tests/unittest/others/test_layer.py |  7 +++----
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/tensorrt_llm/layers/attention.py b/tensorrt_llm/layers/attention.py
index e9f298836ff..29b63a4258c 100755
--- a/tensorrt_llm/layers/attention.py
+++ b/tensorrt_llm/layers/attention.py
@@ -702,7 +702,9 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
         else:
 
-            def register_rope_params(rotary_base, scale, scale_type, scaling,
+            def register_rope_params(rotary_base, rotary_embedding_scale,
+                                     rotary_embedding_scale_type,
+                                     rotary_embedding_scaling,
                                      names_to_register):
                 # Rotary const weights.
                 embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
@@ -712,7 +714,8 @@ def register_rope_params(rotary_base, scale, scale_type, scaling,
 
                 rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                     max_position_embeddings, rotary_embedding_dim, rotary_base,
-                    scale, scale_type, scaling)
+                    rotary_embedding_scale, rotary_embedding_scale_type,
+                    rotary_embedding_scaling)
                 model_cls.register_parameter(
                     names_to_register[0],
                     Parameter(embed_positions, dtype='float32', is_buffer=True))
@@ -725,14 +728,15 @@ def register_rope_params(rotary_base, scale, scale_type, scaling,
                               dtype='float32',
                               is_buffer=True))
 
-            register_rope_params(rotary_base=rotary_embedding_base,
-                                 scale=rotary_embedding_scale,
-                                 scale_type=rotary_embedding_scale_type,
-                                 scaling=rotary_embedding_scaling,
-                                 names_to_register=[
-                                     'embed_positions', 'rotary_inv_freq',
-                                     'embed_positions_for_gpt_attention'
-                                 ])
+            register_rope_params(
+                rotary_base=rotary_embedding_base,
+                rotary_embedding_scale=rotary_embedding_scale,
+                rotary_embedding_scale_type=rotary_embedding_scale_type,
+                rotary_embedding_scaling=rotary_embedding_scaling,
+                names_to_register=[
+                    'embed_positions', 'rotary_inv_freq',
+                    'embed_positions_for_gpt_attention'
+                ])
 
             # For models with non-homegeneous attention layers requiring a second set of rope params. e.g. Gemma3.
             rotary_embedding_base_local = getattr(config,
@@ -740,9 +744,9 @@ def register_rope_params(rotary_base, scale, scale_type, scaling,
             if rotary_embedding_base_local is not None:
                 register_rope_params(
                     rotary_base=rotary_embedding_base_local,
-                    scale=1.0,
-                    scale_type=RotaryScalingType.none,
-                    scaling=None,
+                    rotary_embedding_scale=1.0,
+                    rotary_embedding_scale_type=RotaryScalingType.none,
+                    rotary_embedding_scaling=None,
                     names_to_register=[
                         'embed_positions_local', 'rotary_inv_freq_local',
                         'embed_positions_for_gpt_attention_local'
diff --git a/tests/unittest/others/test_layer.py b/tests/unittest/others/test_layer.py
index 5f80622af1d..38bb7f1ef70 100644
--- a/tests/unittest/others/test_layer.py
+++ b/tests/unittest/others/test_layer.py
@@ -2184,7 +2184,6 @@ def register_parameter(cls, name, param):
         # The global and local inv_freq should be different because:
         # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
         # 2. Local uses scale=1.0 (no scaling)
-        # Also they use different base frequencies (1000000 vs 10000)
         self.assertFalse(
             np.allclose(global_inv_freq, local_inv_freq),
             "Global and local rotary_inv_freq should be different "
@@ -2197,8 +2196,8 @@ def register_parameter(cls, name, param):
             "(global has scaling, local does not)")
 
         # Additional verification: Check that local inv_freq matches unscaled calculation
-        # For local attention with scale=1.0 and base=10000:
-        # inv_freq = 1.0 / (10000 ** (arange(0, dim, 2) / dim))
+        # For local attention with scale=1.0 and base=10:
+        # inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
         dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
         expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
                                          **(np.arange(0, dim, 2) / dim))
@@ -2211,7 +2210,7 @@ def register_parameter(cls, name, param):
 
         # For global attention with linear scaling (factor=8.0):
         # scale = 1.0 / 8.0 = 0.125
-        # inv_freq = 0.125 / (1000000 ** (arange(0, dim, 2) / dim))
+        # inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
         expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
                                                   (np.arange(0, dim, 2) / dim))
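
Note (not part of the patches above): for readers following the math that the unit test asserts, the sketch below is a minimal standalone illustration, assuming only numpy; the helper name rope_inv_freq and the values passed in the __main__ block are illustrative, not TensorRT-LLM API. It shows the relationship the test relies on: linear ("position interpolation") RoPE scaling divides the inverse frequencies by the scaling factor, while the local-attention path leaves them unscaled.

# Illustrative sketch only -- mirrors the expected values in the test, not library code.
import numpy as np


def rope_inv_freq(base, dim, linear_scale_factor=1.0):
    # inv_freq = 1 / base^(2i/dim); linear scaling divides it by the factor,
    # i.e. multiplies by scale = 1 / factor (scale = 1.0 means no scaling).
    inv_freq = 1.0 / (base**(np.arange(0, dim, 2) / dim))
    return inv_freq / linear_scale_factor


if __name__ == '__main__':
    dim = 128  # head_size * rotary_pct in the test's mock config
    global_inv_freq = rope_inv_freq(100.0, dim, linear_scale_factor=8.0)
    local_inv_freq = rope_inv_freq(10.0, dim)  # local attention: unscaled
    assert not np.allclose(global_inv_freq, local_inv_freq)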