162 changes: 106 additions & 56 deletions tests/model_executor/test_routed_experts_capture.py
@@ -185,59 +185,109 @@ def capture(self, layer_id, topk_ids):
    assert len(capturer.calls) == 1


def test_routed_experts_capturer_single_dp_no_metadata():
    """dp_metadata is None: capture writes the full topk_ids rows."""
    capturer = _capturer_with_buffer(dp_rank=0)
    topk = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)
    ctx = SimpleNamespace(dp_metadata=None)
    with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx):
        capturer.capture(layer_id=0, topk_ids=topk)
    assert torch.equal(capturer._device_buffer[:3, 0, :], topk)
    assert capturer._device_buffer[3, 0, 0].item() == -1


def test_routed_experts_capturer_dp_naive_concatenated_all_ranks():
    """n == sum(num_tokens_dp): slice this rank's segment from concatenated topk."""
    capturer = _capturer_with_buffer(dp_rank=1)
    num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32)
    ctx = SimpleNamespace(
        dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp)
    )
    # Concatenated order: rank0 rows then rank1 rows.
    topk = torch.tensor(
        [[0, 1], [2, 3], [10, 11], [12, 13], [14, 15]], dtype=torch.int32
    )
    with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx):
        capturer.capture(layer_id=0, topk_ids=topk)
    want = topk[2:5]
    assert torch.equal(capturer._device_buffer[:3, 0, :], want)


def test_routed_experts_capturer_dp_modular_local_tokens():
    """n == token_num_per_dp: topk is already local to this DP rank."""
    capturer = _capturer_with_buffer(dp_rank=1)
    num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32)
    ctx = SimpleNamespace(
        dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp)
    )
    topk = torch.tensor([[10, 11], [12, 13], [14, 15]], dtype=torch.int32)
    with patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx):
        capturer.capture(layer_id=0, topk_ids=topk)
    assert torch.equal(capturer._device_buffer[:3, 0, :], topk)


def test_routed_experts_capturer_dp_unexpected_batch_raises():
    """Mismatch between topk batch dim and DP layout: fail fast."""
    capturer = _capturer_with_buffer(dp_rank=0)
    num_tokens_dp = torch.tensor([2, 3], dtype=torch.int32)
    ctx = SimpleNamespace(
        dp_metadata=SimpleNamespace(num_tokens_across_dp_cpu=num_tokens_dp)
    )
    # total=5, local=2: n=1 matches neither naive (5) nor modular (2).
    topk = torch.tensor([[1, 2]], dtype=torch.int32)
    with (
        patch(f"{_REC_MODULE}.get_forward_context", return_value=ctx),
        pytest.raises(AssertionError, match="unexpected topk_ids batch dim"),
    ):
        capturer.capture(layer_id=0, topk_ids=topk)
    assert capturer._device_buffer[0, 0, 0].item() == -1
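Taken together, the three DP tests above pin down the row-selection logic inside capture(). A minimal illustrative reconstruction (derived from the tests, not the actual vLLM source; names follow the test fixtures):

```python
def select_local_rows(topk_ids, dp_rank, num_tokens_dp):
    """Pick this DP rank's rows out of a captured topk_ids tensor."""
    n = topk_ids.shape[0]
    total = int(num_tokens_dp.sum())
    local = int(num_tokens_dp[dp_rank])
    if n == total:
        # Naive path: all ranks' rows are concatenated; slice ours out.
        start = int(num_tokens_dp[:dp_rank].sum())
        return topk_ids[start : start + local]
    if n == local:
        # Modular path: rows are already local to this DP rank.
        return topk_ids
    raise AssertionError(f"unexpected topk_ids batch dim: {n}")
```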
# =========================================================================
# Tests for device-cache routing replay architecture
# =========================================================================


class TestRoutedExpertsDeviceCache:
    """Tests for _RoutedExpertsDeviceCache (GPU buffer for routing data)."""

    def test_allocation_shape_and_dtype(self):
        """Device cache allocates (L, N, K) int16 buffer."""
        from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
            _RoutedExpertsDeviceCache,
        )

        cache = _RoutedExpertsDeviceCache(
            num_hidden_layers=40,
            max_num_batched_tokens=8192,
            num_experts_per_tok=8,
        )
        assert cache.buffer.shape == (40, 8192, 8)
        assert cache.buffer.dtype == torch.int16

    def test_per_layer_view_is_contiguous(self):
        """buffer[layer_id] gives contiguous (N, K) view for FlashInfer."""
        from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
            _RoutedExpertsDeviceCache,
        )

        cache = _RoutedExpertsDeviceCache(
            num_hidden_layers=40,
            max_num_batched_tokens=8192,
            num_experts_per_tok=8,
        )
        layer_view = cache.buffer[0]
        assert layer_view.is_contiguous()
        assert layer_view.shape == (8192, 8)


class TestRoutedExpertsHostCache:
    """Tests for _RoutedExpertsHostCache (per-request numpy buffer)."""

    def test_sentinel_initialization(self):
        """Host cache initializes with -1 sentinel (not zeros).

        This is critical: expert ID 0 is valid, so we use -1 to mark
        positions with no routing data (e.g., prefix-cached tokens).
        """
        from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
            _RoutedExpertsHostCache,
        )
        import numpy as np

        cache = _RoutedExpertsHostCache(
            num_hidden_layers=40,
            num_experts_per_tok=8,
        )
        buf = cache.get_or_grow_buffer("req1", max_pos=100)
        assert buf.dtype == np.int16
        assert (buf == -1).all(), "Host cache must initialize with -1, not zeros"

    def test_grow_preserves_existing_data(self):
        """Growing the buffer preserves previously written data."""
        from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
            _RoutedExpertsHostCache,
        )

        cache = _RoutedExpertsHostCache(
            num_hidden_layers=40,
            num_experts_per_tok=8,
        )
        buf = cache.get_or_grow_buffer("req1", max_pos=50)
        buf[0, 0, 0] = 42
        buf2 = cache.get_or_grow_buffer("req1", max_pos=200)
        assert buf2[0, 0, 0] == 42, "Data lost during buffer grow"

    def test_free_request_removes_buffer(self):
        """Freeing a request removes its buffer."""
        from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
            _RoutedExpertsHostCache,
        )

        cache = _RoutedExpertsHostCache(
            num_hidden_layers=40,
            num_experts_per_tok=8,
        )
        cache.get_or_grow_buffer("req1", max_pos=50)
        cache.free_request("req1")
        assert cache.get_buffer("req1") is None
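The host-cache tests above fully specify a small contract: -1 sentinel initialization, data-preserving growth, and per-request teardown. An illustrative sketch satisfying that contract (not the actual vLLM implementation):

```python
import numpy as np


class HostCacheSketch:
    """Per-request (max_pos, num_layers, top_k) int16 buffers."""

    def __init__(self, num_hidden_layers: int, num_experts_per_tok: int):
        self._shape_tail = (num_hidden_layers, num_experts_per_tok)
        self._buffers: dict[str, np.ndarray] = {}

    def get_or_grow_buffer(self, req_id: str, max_pos: int) -> np.ndarray:
        old = self._buffers.get(req_id)
        if old is not None and old.shape[0] >= max_pos:
            return old
        # -1 sentinel: expert ID 0 is valid, so zero-init would be ambiguous.
        new = np.full((max_pos, *self._shape_tail), -1, dtype=np.int16)
        if old is not None:
            new[: old.shape[0]] = old  # growing preserves written data
        self._buffers[req_id] = new
        return new

    def get_buffer(self, req_id: str) -> np.ndarray | None:
        return self._buffers.get(req_id)

    def free_request(self, req_id: str) -> None:
        self._buffers.pop(req_id, None)
```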


class TestMonolithicWritesFlag:
    """Test that quant methods correctly declare routing replay capability."""

    def test_fp8_has_flag(self):
        from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod

        assert getattr(Fp8MoEMethod, "_monolithic_writes_routing_replay", False)

    def test_base_class_no_flag(self):
        from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
            FusedMoEMethodBase,
        )

        assert not getattr(
            FusedMoEMethodBase, "_monolithic_writes_routing_replay", False
        )
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/chat_completion/protocol.py
@@ -92,6 +92,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
    # not part of the OpenAI spec but is useful for tracing the tokens
    # in agent scenarios
    token_ids: list[int] | None = None
    routed_experts: list[list[list[int]]] | None = None  # [seq_len, num_layers, top_k]


class ChatCompletionResponse(OpenAIBaseModel):
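For API consumers the new field nests token → layer → expert IDs; a short access sketch (response object hypothetical, shape per the comment above):

```python
# choice.routed_experts[t][l] -> top-k expert IDs for token t at layer l
experts = response.choices[0].routed_experts
if experts is not None:
    seq_len = len(experts)
    token0_layer0 = experts[0][0]  # num_experts_per_tok ints
```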
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/chat_completion/serving.py
@@ -1283,6 +1283,11 @@ async def chat_completion_full_generator(
            token_ids=(
                as_list(output.token_ids) if request.return_token_ids else None
            ),
            routed_experts=(
                output.routed_experts.tolist()
                if output.routed_experts is not None
                else None
            ),
        )
        choices.append(choice_data)
        continue
@@ -1482,6 +1487,11 @@ async def chat_completion_full_generator(
            token_ids=(
                as_list(output.token_ids) if request.return_token_ids else None
            ),
            routed_experts=(
                output.routed_experts.tolist()
                if output.routed_experts is not None
                else None
            ),
        )
        choice_data = maybe_filter_parallel_tool_calls(choice_data, request)

1 change: 1 addition & 0 deletions vllm/entrypoints/openai/completion/protocol.py
@@ -468,6 +468,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
    token_ids: list[int] | None = None  # For response
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None
    prompt_token_ids: list[int] | None = None  # For prompt
    routed_experts: list[list[list[int]]] | None = None  # [seq_len, num_layers, top_k]


class CompletionResponse(OpenAIBaseModel):
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/completion/serving.py
@@ -531,6 +531,11 @@ def request_output_to_completion_response(
            token_ids=(
                as_list(output.token_ids) if request.return_token_ids else None
            ),
            routed_experts=(
                output.routed_experts.tolist()
                if output.routed_experts is not None
                else None
            ),
        )
        choices.append(choice_data)

1 change: 1 addition & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -169,5 +169,6 @@ def apply_monolithic(
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        router_logits: torch.Tensor,
        routing_replay_out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError
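The new routing_replay_out argument gives monolithic kernels a caller-provided destination for the selected expert IDs. A hedged sketch of how an override might honor it, with routing stubbed via torch.topk (real quant methods do their own expert selection):

```python
import torch


def apply_monolithic_sketch(
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    routing_replay_out: torch.Tensor | None = None,
) -> torch.Tensor:
    # Stub routing; real implementations select experts internally.
    topk_weights, topk_ids = torch.topk(router_logits, top_k, dim=-1)
    if routing_replay_out is not None:
        # Record routing decisions in the preallocated buffer so they can
        # be replayed later without re-running the router.
        routing_replay_out[: topk_ids.shape[0]].copy_(topk_ids.to(torch.int16))
    # Expert computation elided; return value shaped like x.
    return x
```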
6 changes: 6 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -215,6 +215,8 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
# --8<-- [start:fused_moe]
@PluggableLayer.register("fused_moe")
class FusedMoE(PluggableLayer):
    # Auto-incrementing layer ID for routing replay buffer binding.
    _next_moe_layer_id: int = 0
    """FusedMoE layer for MoE models.
    This layer contains both MergedColumnParallel weights (gate_up_proj /
@@ -280,6 +282,10 @@ def __init__(

        self._routed_input_transform = routed_input_transform

        # Assign unique layer ID for routing replay buffer binding.
        self.moe_layer_id = FusedMoE._next_moe_layer_id
        FusedMoE._next_moe_layer_id += 1

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
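Construction order thus gives every FusedMoE instance a stable moe_layer_id, which is what lets a layer address its own slice of the (L, N, K) device cache; conceptually (wiring hypothetical, names from this PR):

```python
# Each MoE layer writes its routing rows into its own contiguous (N, K) view.
layer_view = device_cache.buffer[layer.moe_layer_id]  # int16, per-layer slice
```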
1 change: 1 addition & 0 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -987,6 +987,7 @@ def apply(
        e_score_correction_bias: torch.Tensor | None = None,
        routed_scaling_factor: float | None = None,
        topk_group: int | None = None,
        routing_replay_out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Same as apply(), except uses router_logits as opposed