Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def test_mixtral_moe(

# need to override the forward context for unittests, otherwise it assumes
# we're running the model forward pass (the model specified in vllm_config)
get_forward_context().remaining_moe_layers = None
get_forward_context().all_moe_layers = None

# Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs)
Expand Down
4 changes: 4 additions & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,10 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""

static_all_moe_layers: list[str] = field(default_factory=list, init=False)
"""The names of all the MOE layers in the model
"""

# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [
Expand Down
21 changes: 8 additions & 13 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,11 @@ class ForwardContext:
# the graph.
#
# The workaround is to store a list of the strings that each of those
# custom ops needs, in reverse order, in the ForwardContext.
# custom ops needs in the ForwardContext (all_moe_layers)
# as well as a counter (moe_layer_index).
# The ForwardContext object is alive for the duration of the forward pass.
# When the custom op needs the string, pop the string from this list.
# When the custom op needs a layer string, get the next string
# from all_moe_layers and increment the counter.
#
# This assumes that the custom operators will always be executed in
# order and that torch.compile will not try to reorder these
Expand All @@ -233,7 +235,8 @@ class ForwardContext:
#
# If this value is None (like in some tests), then we end up baking the string
# into the graph. Otherwise, the moe custom ops will read the next string from
# this list, indexed by moe_layer_index.
remaining_moe_layers: list[str] | None = None
all_moe_layers: list[str] | None = None
moe_layer_index: int = 0

additional_kwargs: dict[str, Any] = field(default_factory=dict)

Expand Down Expand Up @@ -271,17 +274,9 @@ def create_forward_context(
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):
no_compile_layers = vllm_config.compilation_config.static_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

remaining_moe_layers = [
name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
]
remaining_moe_layers.reverse()

return ForwardContext(
no_compile_layers=no_compile_layers,
remaining_moe_layers=remaining_moe_layers,
no_compile_layers=vllm_config.compilation_config.static_forward_context,
all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {},
Expand Down
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def __init__(
if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self
compilation_config.static_all_moe_layers.append(prefix)
self.layer_name = prefix

self.enable_eplb = enable_eplb
Expand Down Expand Up @@ -1566,7 +1567,7 @@ def encode_layer_name() -> str:
# Can be unavailable or None in unittests
if (
is_forward_context_available()
and get_forward_context().remaining_moe_layers is not None
and get_forward_context().all_moe_layers is not None
):
return "from_forward_context"
return self.layer_name
Expand Down Expand Up @@ -1987,13 +1988,17 @@ def extra_repr(self) -> str:
def get_layer_from_name(layer_name: str) -> FusedMoE:
    """Resolve a ``FusedMoE`` layer object from its (possibly symbolic) name.

    When ``layer_name`` is the sentinel ``"from_forward_context"`` (emitted by
    the moe custom ops so that layer names are not baked into a compiled
    graph), the real name is looked up positionally: ``all_moe_layers`` holds
    every MoE layer name in execution order, and ``moe_layer_index`` counts how
    many have been consumed during this forward pass.

    Args:
        layer_name: A concrete layer name, or ``"from_forward_context"``.

    Returns:
        The ``FusedMoE`` instance registered in the forward context's
        ``no_compile_layers`` map.

    Raises:
        AssertionError: If more moe custom-op calls occur than there are
            entries in ``all_moe_layers`` (the two must correspond 1:1).
    """
    forward_context: ForwardContext = get_forward_context()
    if layer_name == "from_forward_context":
        all_moe_layers = forward_context.all_moe_layers
        # The sentinel is only emitted when all_moe_layers was populated.
        assert all_moe_layers is not None
        moe_layer_index = forward_context.moe_layer_index
        if moe_layer_index >= len(all_moe_layers):
            raise AssertionError(
                "We expected the number of MOE layers in `all_moe_layers` "
                "to be equal to the number of "
                "{vllm.moe_forward, vllm.moe_forward_shared} calls."
            )
        layer_name = all_moe_layers[moe_layer_index]
        # Advance the cursor so the next custom-op call gets the next layer.
        forward_context.moe_layer_index += 1
    self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
    return self

Expand Down