12 changes: 12 additions & 0 deletions python/sglang/srt/models/deepseek_v2.py
@@ -2296,6 +2296,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
# of the (i-1)th layer as aux hidden state
self.model.layers_to_capture = [val + 1 for val in layer_ids]

def set_dflash_layers_to_capture(self, layer_ids: List[int]):
if not self.pp_group.is_last_rank:
return

if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)

self.capture_aux_hidden_states = True
self.model.layers_to_capture = [val + 1 for val in layer_ids]


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
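A minimal usage sketch of the new hook (illustrative, not part of this diff), assuming a loaded DeepseekV2ForCausalLM instance named target_model and made-up layer ids:

    # Hypothetical caller, e.g. speculative-decoding worker setup.
    target_model.set_dflash_layers_to_capture([2, 29, 58])  # illustrative HF-style layer ids
    # Last PP rank: capture_aux_hidden_states becomes True and the ids are shifted by +1
    # into SGLang's capture-before-layer indexing.
    # Other PP ranks: the call returns immediately without touching model state.
    # Passing None on the last rank raises ValueError rather than silently capturing nothing.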
15 changes: 15 additions & 0 deletions python/sglang/srt/models/gpt_oss.py
@@ -1175,6 +1175,9 @@ def _load_normal_weights(
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight

def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens

def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
@@ -1197,6 +1200,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
# of the (i-1)th layer as aux hidden state
self.model.layers_to_capture = [val + 1 for val in layer_ids]

def set_dflash_layers_to_capture(self, layer_ids: List[int]):
if not self.pp_group.is_last_rank:
return

if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)

self.capture_aux_hidden_states = True
self.model.layers_to_capture = [val + 1 for val in layer_ids]

@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
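One plausible consumer of the new get_input_embeddings accessor (a sketch under assumptions; the draft-side names are hypothetical and not defined in this PR): a DFLASH draft that shares the target's embedding and LM head weights instead of loading its own copies.

    # Hypothetical draft-side wiring for weight sharing.
    embed_module = target_model.get_input_embeddings()             # the embed_tokens module
    embed_weight, head_weight = target_model.get_embed_and_head()  # weight tensors
    draft_model.set_embed_and_head(embed_weight, head_weight)      # assumed draft-model API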
24 changes: 24 additions & 0 deletions python/sglang/srt/models/kimi_k25.py
@@ -849,6 +849,30 @@ def set_eagle3_layers_to_capture(

self.language_model.set_eagle3_layers_to_capture(layer_ids)

def set_dflash_layers_to_capture(self, layer_ids: List[int]) -> None:
"""Set the layers to capture for DFLASH draft model training."""
if not hasattr(self.language_model, "set_dflash_layers_to_capture"):
raise AttributeError(
"language_model does not support DFLASH layer capture."
)

self.language_model.set_dflash_layers_to_capture(layer_ids)

def get_input_embeddings(self):
if not hasattr(self.language_model, "get_input_embeddings"):
raise AttributeError(
"language_model does not support get_input_embeddings()."
)

return self.language_model.get_input_embeddings()

@property
def lm_head(self):
if not hasattr(self.language_model, "lm_head"):
raise AttributeError("language_model does not expose lm_head.")

return self.language_model.lm_head

def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get embedding and LM head weights for speculative decoding."""
if not hasattr(self.language_model, "get_embed_and_head"):
14 changes: 14 additions & 0 deletions python/sglang/srt/models/qwen3.py
@@ -686,5 +686,19 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
else:
self.model.layers_to_capture = [val + 1 for val in layer_ids]

def set_dflash_layers_to_capture(self, layer_ids: List[int]):
if not self.pp_group.is_last_rank:
return

if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)

self.capture_aux_hidden_states = True
# SGLang captures "before layer i". To capture the hidden state after target
# layer `k` (HF-style), we capture before layer `k + 1`.
self.model.layers_to_capture = [val + 1 for val in layer_ids]


EntryClass = Qwen3ForCausalLM
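A worked example of the indexing comment in the method above (illustrative values): SGLang's layers_to_capture means "capture the hidden state entering layer i", so asking for the output of HF-style layer k stores k + 1.

    # On the last PP rank, requesting the output of HF layer 20 of Qwen3:
    model.set_dflash_layers_to_capture([20])
    assert model.capture_aux_hidden_states is True
    assert model.model.layers_to_capture == [21]  # captured before layer 21 == output of layer 20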
39 changes: 34 additions & 5 deletions python/sglang/srt/models/qwen3_5.py
@@ -574,8 +574,15 @@ def forward(
):
forward_batch = kwargs.get("forward_batch", None)

hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
hidden_states, residual = (
self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
hidden_states,
residual,
forward_batch,
captured_last_layer_outputs=kwargs.get(
"captured_last_layer_outputs", None
),
)
)

if not forward_batch.forward_mode.is_idle():
@@ -825,10 +832,16 @@ def forward(
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
captured_last_layer_outputs: Optional[list[torch.Tensor]] = None,
**kwargs,
):
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
hidden_states, residual = (
self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
hidden_states,
residual,
forward_batch,
captured_last_layer_outputs=captured_last_layer_outputs,
)
)

if not forward_batch.forward_mode.is_idle():
@@ -945,9 +958,16 @@ def get_layer(idx: int, prefix: str):
else:
self.norm = PPMissingLayer()

self.layers_to_capture = []

def get_input_embeddings(self):
return self.embed_tokens

def set_dflash_layers_to_capture(self, layers_to_capture: list[int]):
self.layers_to_capture = layers_to_capture
for layer_id in self.layers_to_capture:
setattr(self.layers[layer_id], "_is_layer_to_capture", True)

@property
def start_layer(self) -> int:
return self._start_layer
@@ -978,6 +998,7 @@ def forward(
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]

aux_hidden_states = []
# Pass through decoder layers
for layer_idx in range(self.start_layer, self.end_layer):
layer = self.layers[layer_idx]
@@ -989,6 +1010,11 @@
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
captured_last_layer_outputs=(
aux_hidden_states
if getattr(layer, "_is_layer_to_capture", False)
else None
),
)

# Process deepstack embeddings if provided
@@ -1018,7 +1044,10 @@ def forward(
else:
hidden_states, _ = self.norm(hidden_states, residual)

return hidden_states
if len(aux_hidden_states) == 0:
return hidden_states

return hidden_states, aux_hidden_states

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
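A caller-side sketch of what the Qwen3.5 changes yield once capture is armed (the driver code below is illustrative, not from the PR; the real unpacking happens in the VL wrapper's forward later in this diff):

    # Flag layers 10 and 20 on the inner Qwen3_5Model (no +1 shift at this level).
    model.set_dflash_layers_to_capture([10, 20])
    out = model(input_ids, positions, forward_batch)
    if isinstance(out, tuple):                    # at least one layer was flagged
        hidden_states, aux_hidden_states = out    # aux: one entry per flagged layer, in order
    else:
        hidden_states = out                       # nothing flagged: return type unchanged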
17 changes: 17 additions & 0 deletions python/sglang/srt/models/qwen3_moe.py
@@ -924,6 +924,11 @@ def __init__(
alt_stream=alt_stream,
)

def set_dflash_layers_to_capture(self, layers_to_capture: List[int]):
self.layers_to_capture = layers_to_capture
for layer_id in self.layers_to_capture:
setattr(self.layers[layer_id], "_is_layer_to_capture", True)


class Qwen3MoeForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
@@ -1079,6 +1084,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
else:
self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])

def set_dflash_layers_to_capture(self, layer_ids: List[int]):
if not self.pp_group.is_last_rank:
return

if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)

self.capture_aux_hidden_states = True
self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids])

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
20 changes: 20 additions & 0 deletions python/sglang/srt/models/qwen3_next.py
@@ -813,6 +813,11 @@ def set_eagle3_layers_to_capture(self, layers_to_capture: list[int]):
for layer_id in self.layers_to_capture:
setattr(self.layers[layer_id], "_is_layer_to_capture", True)

def set_dflash_layers_to_capture(self, layers_to_capture: list[int]):
self.layers_to_capture = layers_to_capture
for layer_id in self.layers_to_capture:
setattr(self.layers[layer_id], "_is_layer_to_capture", True)

def forward(
self,
input_ids: torch.Tensor,
@@ -947,6 +952,9 @@ def forward(
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight

def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens

def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
@@ -1127,5 +1135,17 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None):
else:
self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])

def set_dflash_layers_to_capture(self, layer_ids: list[int]):
if not self.pp_group.is_last_rank:
return

if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)

self.capture_aux_hidden_states = True
self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids])


EntryClass = Qwen3NextForCausalLM
16 changes: 16 additions & 0 deletions python/sglang/srt/models/qwen3_vl.py
@@ -1122,6 +1122,7 @@ def __init__(

self.logits_processor = LogitsProcessor(self.config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.capture_aux_hidden_states = False
# like {8: 0, 16: 1, 24: 2}: deepstack features captured at layers 8, 16, 24 are
# merged into layers 0, 1, 2 of the decoder output hidden_states

@@ -1246,19 +1247,34 @@ def forward(
pp_proxy_tensors=pp_proxy_tensors,
)

aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states

if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
)
else:
return self.pooler(hidden_states, forward_batch)
else:
return hidden_states

def set_dflash_layers_to_capture(self, layer_ids: List[int]):
if not self.pp_group.is_last_rank:
return
if layer_ids is None:
raise ValueError(
"DFLASH requires explicit layer_ids for aux hidden capture."
)
self.capture_aux_hidden_states = True
self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids])

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)