Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion tests/model_executor/test_routed_experts_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class DummyRouter(BaseRouter):
@property
def routing_method_type(self) -> RoutingMethodType:
return RoutingMethodType.FUSED_TOPK
return RoutingMethodType.TopK

def _compute_routing(self, hidden_states, router_logits, indices_type):
topk_ids = torch.tensor([[1, 2], [3, 4]], dtype=torch.int64)
Expand Down Expand Up @@ -158,3 +158,26 @@ def capture(self, layer_id, topk_ids):
assert callable(dummy_module.router.capture_fn)
dummy_module.router.capture_fn(torch.tensor([[9, 10]]))
assert len(capturer.calls) == 1


@pytest.mark.parametrize(
    "scoring_func,top_k,renormalize,expected",
    [
        ("sigmoid", 1, False, RoutingMethodType.Llama4),
        ("sigmoid", 2, False, RoutingMethodType.DeepSeekV3),
        ("softmax", 2, False, RoutingMethodType.Default),
        ("softmax", 2, True, RoutingMethodType.Renormalize),
    ],
)
def test_routing_method_type_from_topk_mapping(
    scoring_func,
    top_k,
    renormalize,
    expected,
):
    """from_topk maps each (scoring_func, top_k, renormalize) combination
    onto the expected RoutingMethodType member."""
    resolved = RoutingMethodType.from_topk(scoring_func, top_k, renormalize)
    assert resolved == expected


def test_routing_method_type_from_topk_invalid_scoring_func():
    """from_topk must reject a scoring function it does not recognize."""
    unknown_scoring_func = "none"
    with pytest.raises(ValueError, match="Unsupported scoring function"):
        RoutingMethodType.from_topk(unknown_scoring_func, 1, False)
18 changes: 18 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,24 @@ class RoutingMethodType(IntEnum):
# Unspecified
Unspecified = 8.0

@staticmethod
def from_topk(
    scoring_func: str,
    top_k: int,
    renormalize: bool,
) -> "RoutingMethodType":
    """Map a TopK router configuration onto a RoutingMethodType.

    Args:
        scoring_func: The router's scoring function; "sigmoid" or
            "softmax" are supported.
        top_k: Number of experts selected per token; must be positive.
        renormalize: Whether topk weights are renormalized after
            selection (only distinguishes softmax variants).

    Returns:
        The RoutingMethodType matching the configuration.

    Raises:
        ValueError: If top_k is not positive or scoring_func is not
            supported.
    """
    # Guard against nonsensical values (0, negatives) instead of
    # silently mapping them to a routing method.
    if top_k < 1:
        raise ValueError(f"Unsupported top_k: {top_k}")
    if scoring_func == "sigmoid":
        # top_k == 1 is the Llama4 routing scheme; any larger top_k
        # with sigmoid scoring follows the DeepSeekV3 scheme
        # (DeepSeek-V3 itself uses top_k == 8).
        return (
            RoutingMethodType.Llama4 if top_k == 1 else RoutingMethodType.DeepSeekV3
        )
    if scoring_func == "softmax":
        return (
            RoutingMethodType.Renormalize
            if renormalize
            else RoutingMethodType.Default
        )
    raise ValueError(f"Unsupported scoring function: {scoring_func}")
Comment on lines +132 to +142
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for sigmoid is too broad. It implicitly maps any top_k value other than 1 to RoutingMethodType.DeepSeekV3. This includes potentially invalid values like 0 or negative numbers, and top_k > 2 which may not be correct for DeepSeekV3. The tests only cover top_k=1 and top_k=2. It would be more robust to explicitly check for supported top_k values and raise an error for unsupported ones.

Suggested change
if scoring_func == "sigmoid":
return (
RoutingMethodType.Llama4 if top_k == 1 else RoutingMethodType.DeepSeekV3
)
if scoring_func == "softmax":
return (
RoutingMethodType.Renormalize
if renormalize
else RoutingMethodType.Default
)
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if scoring_func == "sigmoid":
if top_k == 1:
return RoutingMethodType.Llama4
if top_k == 2:
return RoutingMethodType.DeepSeekV3
elif scoring_func == "softmax":
return (
RoutingMethodType.Renormalize
if renormalize
else RoutingMethodType.Default
)
raise ValueError(
f"Unsupported scoring function '{scoring_func}' or top_k '{top_k}'")



@dataclass
class FusedMoEQuantDesc:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _supports_routing_method(
"""Monolithic kernels need to express router support."""
if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
# NOTE(rob): potentially allow others here. This is a conservative list.
# Default routing is not implemented in FlashInfer TRTLLM.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Renormalize,
Expand All @@ -85,7 +86,6 @@ def _supports_routing_method_bf16(
routing_method: RoutingMethodType,
) -> bool:
return routing_method in [
RoutingMethodType.Default,
RoutingMethodType.Renormalize,
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,8 @@ def __init__(

@property
def routing_method_type(self) -> RoutingMethodType:
    """Routing method implied by this router's TopK configuration.

    Derived from scoring_func / top_k / renormalize via the shared
    RoutingMethodType.from_topk mapping so all routers agree.
    """
    return RoutingMethodType.from_topk(
        self.scoring_func, self.top_k, self.renormalize
    )

def _compute_routing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,8 @@ def __init__(

@property
def routing_method_type(self) -> RoutingMethodType:
    """Routing method implied by this router's TopK configuration.

    Derived from scoring_func / top_k / renormalize via the shared
    RoutingMethodType.from_topk mapping so all routers agree.
    """
    return RoutingMethodType.from_topk(
        self.scoring_func, self.top_k, self.renormalize
    )

def _compute_routing(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def flashinfer_trtllm_fp4_moe(
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
routed_scaling_factor=None,
routing_method_type=routing_method_type,
routing_method_type=1,
do_finalize=True,
)[0]

Expand Down