diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index e2904de9aa72..0f8adc6e4ac2 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -1003,6 +1003,7 @@ def _verify_args(self) -> Self: "kimi_k25", "minimax_m2", "gemma4", + "laguna", ] if ( self.method in ("eagle3", "extract_hidden_states", "dflash") diff --git a/vllm/model_executor/models/laguna.py b/vllm/model_executor/models/laguna.py index 08f35d691817..5bf5b7cb021e 100644 --- a/vllm/model_executor/models/laguna.py +++ b/vllm/model_executor/models/laguna.py @@ -39,7 +39,12 @@ default_weight_loader, maybe_remap_kv_scale_name, ) -from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.interfaces import ( + EagleModelMixin, + SupportsEagle3, + SupportsLoRA, + SupportsPP, +) from vllm.model_executor.models.utils import ( AutoWeightsLoader, PPMissingLayer, @@ -554,7 +559,7 @@ def forward( @support_torch_compile -class LagunaModel(nn.Module): +class LagunaModel(nn.Module, EagleModelMixin): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -633,8 +638,17 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): + aux_hidden_states = self._maybe_add_hidden_state( + [], self.start_layer, hidden_states, residual + ) + for layer_idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer), + start=self.start_layer, + ): hidden_states, residual = layer(positions, hidden_states, residual) + self._maybe_add_hidden_state( + aux_hidden_states, layer_idx + 1, hidden_states, residual + ) if not get_pp_group().is_last_rank: return IntermediateTensors( @@ -642,6 +656,8 @@ def forward( ) hidden_states, _ = self.norm(hidden_states, residual) + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -821,7 +837,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -class LagunaForCausalLM(nn.Module, SupportsPP, SupportsLoRA): +class LagunaForCausalLM(nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3): fall_back_to_pt_during_load = False packed_modules_mapping = {