diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py
index d892c515254..d9e0293b76a 100644
--- a/python/sglang/srt/models/gemma3_causal.py
+++ b/python/sglang/srt/models/gemma3_causal.py
@@ -47,6 +47,12 @@ from sglang.srt.utils import add_prefix, make_layers
 
 
 
+# Aligned with HF's implementation, which treats the sliding window as inclusive of
+# the last token; SGLang assumes it is exclusive, hence the `- 1`.
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
@@ -170,7 +176,7 @@ def __init__(
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config.sliding_window
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
@@ -184,6 +190,8 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # The module must also define `get_attention_sliding_window_size` to correctly
+            # initialize the attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
@@ -609,6 +617,9 @@ def __init__(
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype
 
@@ -621,7 +632,6 @@ def forward(
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
-
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, **kwargs
         )
diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py
index c357bf9e595..80dd7197a37 100644
--- a/python/sglang/srt/models/gemma3_mm.py
+++ b/python/sglang/srt/models/gemma3_mm.py
@@ -268,6 +268,12 @@ def prepare_attn_masks(
     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
 
+    def get_attention_sliding_window_size(self):
+        """
+        This value is used to initialize attention backends in `ForwardBatch`.
+        """
+        return self.language_model.get_attention_sliding_window_size()
+
     def get_image_feature(self, image_input: MultimodalInputs):
         """
         Projects the last hidden state from the vision model into language model space.
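
Note on the `- 1` in `get_attention_sliding_window_size`: HF's `config.sliding_window` counts the current token as part of the window, while SGLang's attention backends expect a window size that excludes it. The sketch below is illustrative only, a minimal check under that assumption; the mask helpers are hypothetical and are not SGLang backend code. It verifies that an inclusive window of N attends to exactly the same positions as an exclusive window of N - 1.

# Illustrative sketch (hypothetical helpers, not SGLang backend code):
# why the conversion in the patch is `config.sliding_window - 1`.
import torch


def hf_inclusive_mask(seq_len: int, window: int) -> torch.Tensor:
    # HF-style convention: token i attends to positions j with
    # i - window < j <= i, i.e. the window of `window` tokens includes i itself.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)


def exclusive_mask(seq_len: int, window_size: int) -> torch.Tensor:
    # Exclusive convention: token i attends to itself plus at most
    # `window_size` previous tokens, i.e. i - window_size <= j <= i.
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j >= i - window_size)


if __name__ == "__main__":
    sliding_window = 4  # stand-in for config.sliding_window
    inclusive = hf_inclusive_mask(8, sliding_window)
    exclusive = exclusive_mask(8, sliding_window - 1)  # the `- 1` from the patch
    assert torch.equal(inclusive, exclusive)
    print(f"inclusive window of {sliding_window} == exclusive window of {sliding_window - 1}")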