2 changes: 2 additions & 0 deletions vllm/config/model.py
@@ -198,6 +198,8 @@ class ModelConfig:
graph and always execute the model in eager mode. If False, we will use
CUDA graph and eager execution in hybrid for maximal performance and
flexibility."""
enable_return_routed_experts: bool = False
Collaborator

just call it return_routed_experts?

Contributor Author

I think this is fine. However, the verl framework currently uses this parameter when integrating with vLLM, so changing it would require coordinated changes in verl as well.

"""Whether to return routed experts."""
max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes from the default for the
1 change: 1 addition & 0 deletions vllm/config/vllm.py
@@ -1352,6 +1352,7 @@ def __str__(self):
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, "
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa

Missing async scheduling disable for routed experts feature

Medium Severity

The PR notes explicitly state that async scheduling should be disabled when enable_return_routed_experts=True, but this isn't implemented. The async scheduling logic in VllmConfig.__post_init__ handles various incompatibility cases (PP > 1, speculative decoding, executor backend) but doesn't include any handling for enable_return_routed_experts. Users in the PR discussion are reporting significant latency issues (10X slower), which could be related to async scheduling interference with the capture/save operations for routed experts.

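To make the suggested guard concrete, here is a minimal sketch in the spirit of the existing `__post_init__` incompatibility handling; the helper name, the `scheduler_config.async_scheduling` attribute, and the fallback behavior are assumptions, not part of this PR:

```python
def _maybe_disable_async_scheduling(vllm_config) -> None:
    """Sketch only: disable async scheduling when routed-expert capture is on.

    Assumes the async scheduling toggle lives at
    vllm_config.scheduler_config.async_scheduling; that attribute name is an
    assumption, not something shown in this diff.
    """
    model_config = vllm_config.model_config
    if model_config is None or not model_config.enable_return_routed_experts:
        return
    if getattr(vllm_config.scheduler_config, "async_scheduling", False):
        # Routed-expert capture writes per-step buffers that async scheduling
        # could overlap with, so fall back to synchronous scheduling.
        vllm_config.scheduler_config.async_scheduling = False
```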

f"kv_cache_dtype={self.cache_config.cache_dtype}, "
f"device_config={self.device_config.device}, "
f"structured_outputs_config={self.structured_outputs_config!r}, "
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -354,6 +354,7 @@ class EngineArgs:
"""Arguments for vLLM engine."""

model: str = ModelConfig.model
enable_return_routed_experts: bool = ModelConfig.enable_return_routed_experts
model_weights: str = ModelConfig.model_weights
served_model_name: str | list[str] | None = ModelConfig.served_model_name
tokenizer: str | None = ModelConfig.tokenizer
@@ -657,6 +658,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**model_kwargs["allow_deprecated_quantization"],
)
model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"])
model_group.add_argument(
"--enable-return-routed-experts",
**model_kwargs["enable_return_routed_experts"],
)
model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"])
model_group.add_argument(
@@ -1239,6 +1244,7 @@ def create_model_config(self) -> ModelConfig:
quantization=self.quantization,
allow_deprecated_quantization=self.allow_deprecated_quantization,
enforce_eager=self.enforce_eager,
enable_return_routed_experts=self.enable_return_routed_experts,
max_logprobs=self.max_logprobs,
logprobs_mode=self.logprobs_mode,
disable_sliding_window=self.disable_sliding_window,
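For reference, a hedged usage sketch of the new engine argument added above; the model name is a placeholder, and only the argument itself comes from this diff:

```python
from vllm.engine.arg_utils import EngineArgs

# Placeholder MoE checkpoint; any routed-experts model would do.
# CLI equivalent: --enable-return-routed-experts
engine_args = EngineArgs(
    model="some/moe-model",
    enable_return_routed_experts=True,
)
```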
3 changes: 3 additions & 0 deletions vllm/entrypoints/llm.py
@@ -158,6 +158,7 @@ class LLM:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
enable_return_routed_experts: Whether to return routed experts.
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
hf_token: The token to use as HTTP bearer authorization for remote files
@@ -209,6 +210,7 @@ def __init__(
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
enable_return_routed_experts: bool = False,
disable_custom_all_reduce: bool = False,
hf_token: bool | str | None = None,
hf_overrides: HfOverrides | None = None,
@@ -317,6 +319,7 @@ def _make_config(value: Any, cls: type[_R]) -> _R:
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
enable_return_routed_experts=enable_return_routed_experts,
disable_custom_all_reduce=disable_custom_all_reduce,
hf_token=hf_token,
hf_overrides=hf_overrides,
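A hedged end-to-end sketch of turning the flag on through the `LLM` entrypoint; the model name is a placeholder, and how the captured expert IDs surface on the outputs is not shown in this hunk, so the example stops at generation:

```python
from vllm import LLM, SamplingParams

# Placeholder MoE checkpoint; replace with a real routed-experts model.
llm = LLM(
    model="some/moe-model",
    enable_return_routed_experts=True,
    enforce_eager=True,  # conservative choice while capturing expert IDs
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=8),
)
# The per-layer routed expert IDs are expected to be attached to the request
# outputs; the exact attribute is defined elsewhere in this PR.
```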
22 changes: 22 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -35,6 +35,9 @@
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
)
from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsCapturer,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
@@ -701,6 +704,13 @@ def maybe_init_modular_kernel(self) -> None:
def shared_experts(self) -> torch.nn.Module | None:
return None

@property
def layer_id(self):
# Delayed import to avoid circular dependency
from vllm.model_executor.models.utils import extract_layer_index

return extract_layer_index(self.layer_name)

@property
def gate(self) -> torch.nn.Module | None:
return None
@@ -1650,6 +1660,18 @@ def valid_grouping() -> bool:

assert topk_ids.dtype == indices_type or indices_type is None

if (
self.vllm_config.model_config is not None
and self.vllm_config.model_config.enable_return_routed_experts
):
# In dummy runs, the capturer is not initialized.
capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None:  # may be None in dummy_run
capturer.capture( # noqa
layer_id=self.layer_id,
topk_ids=topk_ids,
)

return topk_weights, topk_ids

def must_reduce_shared_expert_outputs(self) -> bool:
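To make the capture path above easier to follow, here is a purely illustrative singleton with the same surface (`get_instance` returning `None` until initialized, and `capture(layer_id, topk_ids)`); the real `RoutedExpertsCapturer` lives in `routed_experts_capturer.py` and is not shown in this hunk, so every detail below is an assumption:

```python
import torch


class IllustrativeExpertsCapturer:
    """Sketch of the capture surface used in select_experts above."""

    _instance: "IllustrativeExpertsCapturer | None" = None

    @classmethod
    def get_instance(cls) -> "IllustrativeExpertsCapturer | None":
        # Returns None until explicitly initialized (e.g. during dummy runs).
        return cls._instance

    @classmethod
    def init_instance(cls) -> "IllustrativeExpertsCapturer":
        cls._instance = cls()
        return cls._instance

    def __init__(self) -> None:
        # layer_id -> captured top-k expert ids for the current step
        self._per_layer: dict[int, torch.Tensor] = {}

    def capture(self, layer_id: int, topk_ids: torch.Tensor) -> None:
        # Store a detached CPU copy so the buffer is not tied to the forward pass.
        self._per_layer[layer_id] = topk_ids.detach().cpu()
```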