Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion python/sglang/srt/models/gemma4_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbeddingShardIndices,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
Expand All @@ -68,6 +71,25 @@ def get_attention_sliding_window_size(config):
Gemma4TextScaledWordEmbedding = Gemma3TextScaledWordEmbedding


def _ensure_dflash_shard_indices(lm_head, vocab_size: int) -> None:
"""Inject a tp=1 ShardIndices namespace into a non-vocab-parallel lm_head
so the SGLang DFLASH worker passes its `hasattr(lm_head, 'shard_indices')`
gate. Gemma 4 uses a tied embed_tokens (nn.Embedding) for lm_head, which
has no shard_indices on its own."""
if getattr(lm_head, "shard_indices", None) is not None:
return
lm_head.shard_indices = VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index=0,
padded_org_vocab_end_index=vocab_size,
padded_added_vocab_start_index=vocab_size,
padded_added_vocab_end_index=vocab_size,
org_vocab_start_index=0,
org_vocab_end_index=vocab_size,
added_vocab_start_index=vocab_size,
added_vocab_end_index=vocab_size,
)


class Gemma4Router(nn.Module):
"""Router for Gemma4 MoE that preprocesses input before projection.

Expand Down Expand Up @@ -934,6 +956,7 @@ def __init__(
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
_ensure_dflash_shard_indices(self.lm_head, config.vocab_size)
self.capture_aux_hidden_states = False
self.post_init()

Expand Down Expand Up @@ -1135,5 +1158,15 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
# of the (i-1)th layer as aux hidden state
self.model.layers_to_capture = [val + 1 for val in layer_ids]

def set_dflash_layers_to_capture(self, layer_ids: List[int]):
if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)
self.capture_aux_hidden_states = True
# we plus 1 here because in sglang, for the ith layer, it takes the output
# of the (i-1)th layer as aux hidden state
self.model.layers_to_capture = [val + 1 for val in layer_ids]


EntryClass = Gemma4ForCausalLM
21 changes: 20 additions & 1 deletion python/sglang/srt/models/gemma4_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@
maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma4_audio import Gemma4AudioEncoder
from sglang.srt.models.gemma4_causal import Gemma4TextModel
from sglang.srt.models.gemma4_causal import (
Gemma4TextModel,
_ensure_dflash_shard_indices,
)
from sglang.srt.models.gemma4_vision import Gemma4VisionEncoder
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor
Expand Down Expand Up @@ -218,6 +221,12 @@ def __init__(
quant_config,
prefix=add_prefix("language_model", prefix),
)
# Gemma 4 ties lm_head to the text embed_tokens; expose it so that
# speculative-decoding workers (e.g. DFLASH) can locate a vocab-parallel
# head. _ensure_dflash_shard_indices injects a tp=1 ShardIndices
# namespace onto the tied embedding for DFLASH's fast greedy path.
self.lm_head = self.language_model.embed_tokens
_ensure_dflash_shard_indices(self.lm_head, config.text_config.vocab_size)

# Create logits processor for the multimodal model
self.logits_processor = LogitsProcessor(config.text_config)
Expand Down Expand Up @@ -967,5 +976,15 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
# of the (i-1)th layer as aux hidden state
self.language_model.layers_to_capture = [val + 1 for val in layer_ids]

def set_dflash_layers_to_capture(self, layer_ids: List[int]):
if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)
self.capture_aux_hidden_states = True
# we plus 1 here because in sglang, for the ith layer, it takes the output
# of the (i-1)th layer as aux hidden state
self.language_model.layers_to_capture = [val + 1 for val in layer_ids]


EntryClass = Gemma4ForConditionalGeneration
Loading