Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.models.interfaces import EagleModelMixin
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.sequence import IntermediateTensors


class PredictableLlamaModel(nn.Module):
class PredictableLlamaModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.aux_hidden_state_layers = tuple[int, ...]()

# Create minimal embed_tokens for embedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
Expand Down
21 changes: 6 additions & 15 deletions vllm/model_executor/models/afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.interfaces import (
EagleModelMixin,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
Expand Down Expand Up @@ -384,7 +385,7 @@ def forward(
"inputs_embeds": 0,
}
)
class AfmoeModel(nn.Module):
class AfmoeModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

Expand Down Expand Up @@ -421,8 +422,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.norm = PPMissingLayer()

self.aux_hidden_state_layers = tuple[int, ...]()

self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
Expand Down Expand Up @@ -453,15 +452,14 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

aux_hidden_states = []
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to use `self.start_layer` as the initial layer index so that the pipeline-parallel (PP) case is handled correctly.

Suggested change
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
aux_hidden_states = self._maybe_add_hidden_state([], self.start_layer, hidden_states, residual)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in the PR description; see also #36151. Let me know if you think it would be better to apply the fix to all the models in this PR.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Fixing this alone likely won't be enough to get PP + spec decode working, since we don't appear to transfer `aux_hidden_states` across PP ranks in the GPU model runner. So this will probably require a larger fix.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed the bug accordingly.

for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual if residual is not None else hidden_states
)
hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
Comment thread
benchislett marked this conversation as resolved.

if not get_pp_group().is_last_rank:
return IntermediateTensors(
Expand Down Expand Up @@ -691,13 +689,6 @@ def set_eplb_state(
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

def forward(
self,
input_ids: torch.Tensor | None,
Expand Down
30 changes: 15 additions & 15 deletions vllm/model_executor/models/apertus.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
Expand Down Expand Up @@ -313,7 +319,7 @@ def forward(


@support_torch_compile
class ApertusModel(nn.Module):
class ApertusModel(nn.Module, EagleModelMixin):
def __init__(
self,
*,
Expand Down Expand Up @@ -357,8 +363,6 @@ def __init__(
else:
self.norm = PPMissingLayer()

self.aux_hidden_state_layers = tuple[int, ...]()

self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
Expand All @@ -384,13 +388,14 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

aux_hidden_states = []
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
Comment thread
benchislett marked this conversation as resolved.

if not get_pp_group().is_last_rank:
return IntermediateTensors(
Expand Down Expand Up @@ -472,7 +477,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return loaded_params


class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class ApertusForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# LoRA specific attributes
Expand Down Expand Up @@ -520,13 +527,6 @@ def __init__(
self.model.make_empty_intermediate_tensors
)

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

def _init_model(
self,
vllm_config: VllmConfig,
Expand Down
27 changes: 15 additions & 12 deletions vllm/model_executor/models/arcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
Expand Down Expand Up @@ -170,7 +176,7 @@ def forward(


@support_torch_compile
class ArceeModel(nn.Module):
class ArceeModel(nn.Module, EagleModelMixin):
"""The transformer model backbone for Arcee (embedding layer + stacked
decoder blocks + final norm)."""

Expand Down Expand Up @@ -218,10 +224,6 @@ def __init__(
else:
self.norm = PPMissingLayer()

# For optional capturing of intermediate hidden states
# (not used by default)
self.aux_hidden_state_layers: tuple[int, ...] = tuple()

# Prepare factory for empty intermediate tensors
# (for pipeline scheduling)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
Expand Down Expand Up @@ -253,15 +255,14 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

aux_hidden_states: list[torch.Tensor] = []
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(
hidden_states + residual
) # capture pre-layer hidden state if needed
hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
Comment thread
benchislett marked this conversation as resolved.

if not get_pp_group().is_last_rank:
# Send intermediate results to the next pipeline stage
Expand Down Expand Up @@ -348,7 +349,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return loaded_params


class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class ArceeForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
"""Arcee Model for causal language modeling, integrated with vLLM
runtime."""

Expand Down
29 changes: 15 additions & 14 deletions vllm/model_executor/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionType

from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .interfaces import (
EagleModelMixin,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
Expand Down Expand Up @@ -256,7 +262,7 @@ def forward(


@support_torch_compile
class GptOssModel(nn.Module):
class GptOssModel(nn.Module, EagleModelMixin):
def __init__(
self,
*,
Expand Down Expand Up @@ -285,7 +291,6 @@ def __init__(
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size
)
self.aux_hidden_state_layers = tuple[int, ...]()

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)
Expand All @@ -309,12 +314,13 @@ def forward(
x = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

aux_hidden_states = []
aux_hidden_states = self._maybe_add_hidden_state(
[], self.start_layer, x, residual
)
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(x if residual is None else x + residual)
x, residual = layer(x, positions, residual)
self._maybe_add_hidden_state(aux_hidden_states, i + 1, x, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": x, "residual": residual})
x, _ = self.norm(x, residual)
Expand Down Expand Up @@ -1141,7 +1147,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
)


class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
class GptOssForCausalLM(
nn.Module, SupportsPP, SupportsEagle, SupportsEagle3, SupportsLoRA
):
is_3d_moe_weight: bool = True
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

Expand Down Expand Up @@ -1197,13 +1205,6 @@ def __init__(
self.model.make_empty_intermediate_tensors
)

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

Expand Down
32 changes: 17 additions & 15 deletions vllm/model_executor/models/hunyuan_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType

from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
Expand Down Expand Up @@ -586,7 +593,7 @@ def forward(
"inputs_embeds": 0,
}
)
class HunYuanModel(nn.Module):
class HunYuanModel(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

Expand Down Expand Up @@ -629,7 +636,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers = tuple[int, ...]()

def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
Expand All @@ -654,13 +660,10 @@ def forward(

cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
aux_hidden_states = []
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for i, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)

hidden_states, residual, kv_states = layer(
positions,
hidden_states,
Expand All @@ -673,6 +676,10 @@ def forward(
else:
prev_kv_states = None

self._maybe_add_hidden_state(
aux_hidden_states, i + 1, hidden_states, residual
)

if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
Expand Down Expand Up @@ -904,7 +911,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
return loaded_params


class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
class HunyuanV1ModelBase(
nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand Down Expand Up @@ -943,13 +952,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = PPMissingLayer()

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)

def forward(
self,
input_ids: torch.Tensor | None,
Expand Down
Loading
Loading