diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 81697d86ab04..8042716ef83d 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2296,6 +2296,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         # of the (i-1)th layer as aux hidden state
         self.model.layers_to_capture = [val + 1 for val in layer_ids]
 
+    def set_dflash_layers_to_capture(self, layer_ids: List[int]):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            raise ValueError(
+                "DFLASH requires explicit layer_ids for aux hidden capture."
+            )
+
+        self.capture_aux_hidden_states = True
+        self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index c2e1ac233028..297371ec64b8 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -1175,6 +1175,9 @@ def _load_normal_weights(
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def set_embed_and_head(self, embed, head):
         del self.model.embed_tokens.weight
         del self.lm_head.weight
@@ -1197,6 +1200,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         # of the (i-1)th layer as aux hidden state
         self.model.layers_to_capture = [val + 1 for val in layer_ids]
 
+    def set_dflash_layers_to_capture(self, layer_ids: List[int]):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            raise ValueError(
+                "DFLASH requires explicit layer_ids for aux hidden capture."
+            )
+
+        self.capture_aux_hidden_states = True
+        self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
     @classmethod
     def get_model_config_for_expert_location(cls, config):
         return ModelConfigForExpertLocation(
diff --git a/python/sglang/srt/models/kimi_k25.py b/python/sglang/srt/models/kimi_k25.py
index c5bfbac1cf1e..7fd30d248fcc 100644
--- a/python/sglang/srt/models/kimi_k25.py
+++ b/python/sglang/srt/models/kimi_k25.py
@@ -849,6 +849,30 @@ def set_eagle3_layers_to_capture(
 
         self.language_model.set_eagle3_layers_to_capture(layer_ids)
 
+    def set_dflash_layers_to_capture(self, layer_ids: List[int]) -> None:
+        """Set the layers to capture for DFLASH draft model training."""
+        if not hasattr(self.language_model, "set_dflash_layers_to_capture"):
+            raise AttributeError(
+                "language_model does not support DFLASH layer capture."
+            )
+
+        self.language_model.set_dflash_layers_to_capture(layer_ids)
+
+    def get_input_embeddings(self):
+        if not hasattr(self.language_model, "get_input_embeddings"):
+            raise AttributeError(
+                "language_model does not support get_input_embeddings()."
+            )
+
+        return self.language_model.get_input_embeddings()
+
+    @property
+    def lm_head(self):
+        if not hasattr(self.language_model, "lm_head"):
+            raise AttributeError("language_model does not expose lm_head.")
+
+        return self.language_model.lm_head
+
     def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get embedding and LM head weights for speculative decoding."""
         if not hasattr(self.language_model, "get_embed_and_head"):
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 7e557a8d5b31..f3bca7a6bb70 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -686,5 +686,19 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         else:
             self.model.layers_to_capture = [val + 1 for val in layer_ids]
 
+    def set_dflash_layers_to_capture(self, layer_ids: List[int]):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            raise ValueError(
+                "DFLASH requires explicit layer_ids for aux hidden capture."
+            )
+
+        self.capture_aux_hidden_states = True
+        # SGLang captures "before layer i". To capture the hidden state after target
+        # layer `k` (HF-style), we capture before layer `k + 1`.
+        self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen3ForCausalLM
diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py
index 53ed08a4cd7e..cd1e68592bbb 100644
--- a/python/sglang/srt/models/qwen3_5.py
+++ b/python/sglang/srt/models/qwen3_5.py
@@ -574,8 +574,15 @@ def forward(
     ):
         forward_batch = kwargs.get("forward_batch", None)
 
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=kwargs.get(
+                    "captured_last_layer_outputs", None
+                ),
+            )
         )
 
         if not forward_batch.forward_mode.is_idle():
@@ -825,10 +832,16 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         forward_batch: ForwardBatch,
+        captured_last_layer_outputs: Optional[list[torch.Tensor]] = None,
         **kwargs,
     ):
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )
 
         if not forward_batch.forward_mode.is_idle():
@@ -945,9 +958,16 @@ def get_layer(idx: int, prefix: str):
         else:
             self.norm = PPMissingLayer()
 
+        self.layers_to_capture = []
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
+    def set_dflash_layers_to_capture(self, layers_to_capture: list[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     @property
     def start_layer(self) -> int:
         return self._start_layer
@@ -978,6 +998,7 @@ def forward(
             hidden_states = pp_proxy_tensors["hidden_states"]
             residual = pp_proxy_tensors["residual"]
 
+        aux_hidden_states = []
         # Pass through decoder layers
         for layer_idx in range(self.start_layer, self.end_layer):
             layer = self.layers[layer_idx]
@@ -989,6 +1010,11 @@ def forward(
                 hidden_states=hidden_states,
                 residual=residual,
                 forward_batch=forward_batch,
+                captured_last_layer_outputs=(
+                    aux_hidden_states
+                    if getattr(layer, "_is_layer_to_capture", False)
+                    else None
+                ),
             )
 
             # Process deepstack embeddings if provided
@@ -1018,7 +1044,10 @@ def forward(
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
 
-        return hidden_states
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 912891b6a7eb..cfd52b65473f 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -924,6 +924,11 @@ def __init__(
             alt_stream=alt_stream,
         )
 
+    def set_dflash_layers_to_capture(self, layers_to_capture: List[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
 
 class Qwen3MoeForCausalLM(nn.Module):
     fall_back_to_pt_during_load = False
@@ -1079,6 +1084,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         else:
             self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
 
+    def set_dflash_layers_to_capture(self, layer_ids: List[int]):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            raise ValueError(
+                "DFLASH requires explicit layer_ids for aux hidden capture."
+            )
+
+        self.capture_aux_hidden_states = True
+        self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids])
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py
index 7e3862a8a61b..ac63cae10ea8 100644
--- a/python/sglang/srt/models/qwen3_next.py
+++ b/python/sglang/srt/models/qwen3_next.py
@@ -813,6 +813,11 @@ def set_eagle3_layers_to_capture(self, layers_to_capture: list[int]):
         for layer_id in self.layers_to_capture:
             setattr(self.layers[layer_id], "_is_layer_to_capture", True)
 
+    def set_dflash_layers_to_capture(self, layers_to_capture: list[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -947,6 +952,9 @@ def forward(
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def set_embed_and_head(self, embed, head):
         del self.model.embed_tokens.weight
         del self.lm_head.weight
@@ -1127,5 +1135,17 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None):
         else:
             self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
 
+    def set_dflash_layers_to_capture(self, layer_ids: list[int]):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            raise ValueError(
+                "DFLASH requires explicit layer_ids for aux hidden capture."
+            )
+
+        self.capture_aux_hidden_states = True
+        self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids])
+
 
 EntryClass = Qwen3NextForCausalLM
diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py
index 7746b2445999..cf5875f97600 100644
--- a/python/sglang/srt/models/qwen3_vl.py
+++ b/python/sglang/srt/models/qwen3_vl.py
@@ -1122,6 +1122,7 @@ def __init__(
 
         self.logits_processor = LogitsProcessor(self.config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.capture_aux_hidden_states = False
 
         # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on
         # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states
@@ -1246,6 +1247,10 @@ def forward(
             pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
@@ -1253,12 +1258,23 @@ def forward(
                     hidden_states,
                     self.lm_head,
                     forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
         else:
             return hidden_states
 
+    def set_dflash_layers_to_capture(self, layer_ids: List[int]):
+        if not self.pp_group.is_last_rank:
+            return
+        if layer_ids is None:
+            raise ValueError(
+                "DFLASH requires explicit layer_ids for aux hidden capture."
+            )
+        self.capture_aux_hidden_states = True
+        self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids])
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
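
The capture path threaded through the diff above can be summarized with a small standalone mock. This is a sketch only, not SGLang code: `MockDecoderLayer`, `MockModel`, the layer count, and the layer indices are illustrative assumptions, and the real path goes through `prepare_attn_and_capture_last_layer_outputs`, `ForwardBatch`, and pipeline-parallel rank checks. It only shows the contract the diff establishes: tagged layers append their incoming hidden state (the previous layer's output) to a shared list, and `forward` returns `(hidden_states, aux_hidden_states)` when anything was captured.

```python
# Self-contained mock of the DFLASH aux-hidden-state capture contract (assumption-based sketch).
from typing import List, Optional

import torch
import torch.nn as nn


class MockDecoderLayer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        # The hidden state is recorded *before* this layer runs, i.e. it is the
        # previous layer's output (mirrors the "capture before layer i" convention).
        if captured_last_layer_outputs is not None:
            captured_last_layer_outputs.append(hidden_states)
        return self.proj(hidden_states)


class MockModel(nn.Module):
    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(MockDecoderLayer(dim) for _ in range(num_layers))
        self.layers_to_capture: List[int] = []

    def set_dflash_layers_to_capture(self, layers_to_capture: List[int]):
        # Same tagging scheme as the diff: mark the layers whose *input* we want.
        self.layers_to_capture = layers_to_capture
        for layer_id in self.layers_to_capture:
            setattr(self.layers[layer_id], "_is_layer_to_capture", True)

    def forward(self, hidden_states: torch.Tensor):
        aux_hidden_states: List[torch.Tensor] = []
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                captured_last_layer_outputs=(
                    aux_hidden_states
                    if getattr(layer, "_is_layer_to_capture", False)
                    else None
                ),
            )
        # Plain tensor when nothing was captured, otherwise a (hidden, aux) tuple,
        # matching the modified model forward above.
        if len(aux_hidden_states) == 0:
            return hidden_states
        return hidden_states, aux_hidden_states


model = MockModel()
# HF-style "output of layer k" corresponds to SGLang-style "input of layer k + 1".
model.set_dflash_layers_to_capture([k + 1 for k in [0, 2]])  # hypothetical capture points
hidden, aux = model(torch.randn(1, 8))
print(len(aux))  # -> 2 captured hidden states for the DFLASH draft model
```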