 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP, SupportsEagle3
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -408,6 +408,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

+        self.aux_hidden_state_layers: tuple[int, ...] = tuple()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -428,14 +430,21 @@ def forward(
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)):
+            if layer_idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
         hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states

     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -579,7 +588,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3,
                           MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -674,6 +683,13 @@ def update_physical_experts_metadata(

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
+
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)

     def forward(
         self,
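For context, a minimal sketch of how a caller might drive the two new hooks; model here is a hypothetical, already-loaded Qwen3MoeForCausalLM instance, and only the method names come from this diff:

    # Illustrative usage only; `model` is assumed to be a loaded Qwen3MoeForCausalLM.
    layers = model.get_eagle3_aux_hidden_state_layers()  # e.g. (2, 24, 45) for a 48-layer model
    model.set_aux_hidden_state_layers(layers)
    # With aux_hidden_state_layers non-empty, the forward pass on the last
    # pipeline-parallel rank returns (hidden_states, aux_hidden_states)
    # instead of a bare hidden_states tensor, so callers must unpack the tuple.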