Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions csrc/moe/topk_softmax_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ struct alignas(Alignment) AlignedArray {
// Converts a value of type T (float, __nv_bfloat16 or __half) to float and
// clamps it from below at -FLT_MAX.
//
// fmaxf returns the non-NaN operand when exactly one operand is NaN, so a NaN
// input comes back as -FLT_MAX; -inf is likewise raised to -FLT_MAX. Every
// value returned from here is therefore an ordered, non-NaN float.
// NOTE(review): presumably this keeps NaN logits out of the later max/top-k
// comparisons -- confirm against the kernel body.
//
// No branch exists for any other T; instantiating with an unsupported type
// falls off the end of the function without returning a value.
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
  if constexpr (std::is_same_v<T, float>) {
    return fmaxf(value, -FLT_MAX);
  } else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
    return fmaxf(__bfloat162float(value), -FLT_MAX);
  } else if constexpr (std::is_same_v<T, __half>) {
    return fmaxf(__half2float(value), -FLT_MAX);
  }
}

Expand Down Expand Up @@ -390,6 +390,11 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}

#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = fmaxf(row_chunk[ii], -FLT_MAX);
}

if constexpr (SF == SCORING_SOFTMAX) {
// First, we perform a max reduce within the thread.
float thread_max = row_chunk[0];
Expand Down
171 changes: 168 additions & 3 deletions tests/kernels/moe/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,52 @@ def assert_aiter_routing_valid(
)


def assert_topk_expert_ids_distinct_and_in_range(
    topk_ids: torch.Tensor,
    top_k: int,
    num_experts: int,
) -> None:
    """Assert that every row of ``topk_ids`` names ``top_k`` different valid experts.

    Three checks run in order, each raising ``AssertionError`` on failure:

    1. ``topk_ids`` has shape ``(n_tokens, top_k)``, where ``n_tokens`` is its
       own leading dimension.
    2. Every entry lies in ``[0, num_experts)``.
    3. Each row holds exactly ``top_k`` unique values.

    Intended for validating fused routing on degenerate inputs (e.g. a token
    whose router logits are all NaN), where repeated expert indices are
    treated as a bug rather than an acceptable outcome.
    """
    num_rows = topk_ids.shape[0]
    expected_shape = (num_rows, top_k)
    assert topk_ids.shape == expected_shape, (
        f"topk_ids shape {topk_ids.shape} != ({num_rows}, {top_k})"
    )

    within_bounds = (topk_ids >= 0) & (topk_ids < num_experts)
    assert within_bounds.all(), (
        f"expert IDs out of range [0, {num_experts}): "
        f"min={topk_ids.min().item()}, max={topk_ids.max().item()}"
    )

    # Iterating a 2-D tensor yields its rows, i.e. one token at a time.
    for row_idx, row in enumerate(topk_ids):
        distinct_count = torch.unique(row).numel()
        assert distinct_count == top_k, (
            f"token {row_idx}: expected {top_k} distinct expert IDs, "
            f"got {row.tolist()}"
        )


def baseline_fused_topk(
router_logits: torch.Tensor, top_k: int, renormalize: bool
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
scoring_func: str = "softmax",
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Baseline for standard fused top-k routing.

Algorithm:
1. Apply softmax to router logits
1. Apply scoring function (softmax or sigmoid) to router logits
2. Select top-k experts
3. Optionally renormalize the weights
"""
scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
if scoring_func == "softmax":
scores = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
elif scoring_func == "sigmoid":
scores = torch.sigmoid(router_logits.float())
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# Use sorted=False to match vllm implementation (vllm_is_batch_invariant
# defaults to False)
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1, sorted=False)
Expand Down Expand Up @@ -735,3 +769,134 @@ def test_eplb_map_with_redundancy(
torch.testing.assert_close(load, exp_load)
else:
assert load.sum().item() == 0


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="NaN behavior is tested on CUDA only"
)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("renormalize", [False, True])
@pytest.mark.parametrize("top_k", [2, 4])
def test_topk_nan_row_distinct_experts(
    scoring_func: str,
    dtype: torch.dtype,
    renormalize: bool,
    top_k: int,
):
    """All-NaN router rows must still be routed to ``top_k`` distinct experts.

    Two of the eight token rows are overwritten with NaN before routing. The
    test then checks that

    * every row, NaN rows included, holds ``top_k`` distinct in-range expert
      IDs (duplicates indicate a bug in the fused routing implementation), and
    * the untouched rows match ``baseline_fused_topk``.
    """
    num_experts, num_tokens, hidden_dim = 16, 8, 64
    poisoned_rows = [1, 5]

    router = create_fused_moe_router(
        top_k=top_k,
        global_num_experts=num_experts,
        renormalize=renormalize,
        enable_eplb=False,
        eplb_state=setup_eplb_state(False, num_experts),
        scoring_func=scoring_func,
    )

    hidden_states = (
        torch.randn((num_tokens, hidden_dim), device="cuda", dtype=dtype) / 10
    )
    router_logits = torch.randn((num_tokens, num_experts), device="cuda", dtype=dtype)
    # Overwrite the selected rows entirely with NaN.
    router_logits[poisoned_rows] = float("nan")

    topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)

    assert_topk_expert_ids_distinct_and_in_range(topk_ids, top_k, num_experts)
    assert topk_weights.shape == (num_tokens, top_k)

    # Only rows left untouched above are compared against the baseline.
    keep = torch.ones(num_tokens, dtype=torch.bool, device="cuda")
    keep[poisoned_rows] = False
    expected_weights, expected_ids = baseline_fused_topk(
        router_logits[keep],
        top_k,
        renormalize,
        scoring_func=scoring_func,
    )
    assert_routing_results_close(
        topk_weights[keep],
        topk_ids[keep],
        expected_weights,
        expected_ids,
    )


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="NaN behavior is tested on CUDA only"
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("renormalize", [False, True])
def test_grouped_topk_nan_row_distinct_experts(
    dtype: torch.dtype,
    scoring_func: str,
    renormalize: bool,
):
    """Grouped top-k routing must yield K distinct experts per token on all-NaN rows.

    The rows listed in ``nan_rows`` are filled with NaN before routing. Every
    row must come back with ``top_k`` distinct, in-range expert IDs; duplicate
    indices indicate a bug in the grouped routing implementation. The
    remaining rows are compared against ``baseline_grouped_topk``.
    """
    num_experts = 64
    num_expert_group = 8
    topk_group = 4
    top_k = 4
    m = 8
    k = 256
    routed_scaling_factor = 1.0
    # Single source of truth for the poisoned rows: used both to inject the
    # NaNs and to build the mask, so the two cannot drift apart.
    nan_rows = (0, 3)

    eplb_state = setup_eplb_state(False, num_experts)
    e_score_correction_bias = make_e_score_correction_bias(0.9, num_experts)

    router = create_fused_moe_router(
        use_grouped_topk=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
        global_num_experts=num_experts,
        renormalize=renormalize,
        enable_eplb=False,
        eplb_state=eplb_state,
    )

    hidden_states = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    router_logits = torch.randn((m, num_experts), device="cuda", dtype=dtype)
    for r in nan_rows:
        router_logits[r].fill_(float("nan"))

    topk_weights, topk_ids = router.select_experts(hidden_states, router_logits)

    assert_topk_expert_ids_distinct_and_in_range(topk_ids, top_k, num_experts)
    assert topk_weights.shape == (m, top_k)

    # Only the rows that were not overwritten are checked against the baseline.
    finite_row_mask = torch.ones(m, dtype=torch.bool, device="cuda")
    finite_row_mask[list(nan_rows)] = False
    baseline_weights, baseline_ids = baseline_grouped_topk(
        router_logits[finite_row_mask],
        top_k,
        num_expert_group,
        topk_group,
        scoring_func,
        renormalize,
        e_score_correction_bias,
        routed_scaling_factor,
    )
    assert_routing_results_close(
        topk_weights[finite_row_mask],
        topk_ids[finite_row_mask],
        baseline_weights,
        baseline_ids,
    )
6 changes: 0 additions & 6 deletions vllm/model_executor/layers/fused_moe/oracle/unquantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ def _move_to_back(
UnquantizedMoeBackend.BATCHED_TRITON,
]

# HACK: Qwen3.5 has crash with FLASHINFER_CUTLASS BF16 if DEP.
# Updating the oracle querying logic is out of the scope of this
# PR. Need to fix the kernel or update structure in follow up.
if moe_config.moe_parallel_config.dp_size > 1:
_move_to_back(_AVAILABLE_BACKENDS, UnquantizedMoeBackend.FLASHINFER_CUTLASS)

elif current_platform.is_xpu():
_AVAILABLE_BACKENDS = [UnquantizedMoeBackend.XPU]
elif current_platform.is_cpu():
Expand Down
Loading