@@ -186,7 +186,9 @@ def allocate_sm120_static_workspace(
         active_expert_count=torch.zeros(1, dtype=torch.int32, device=device),
         weight_expert_ids=torch.arange(state_E, dtype=torch.int32, device=device),
         global_to_local_expert=torch.empty(weight_E, dtype=torch.int32, device=device),
-        compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device),
+        compact_topk_ids=torch.empty(
+            max(state_E, max_rows), dtype=torch.int32, device=device
+        ),
     )

     # Finalize views
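For context, here is a minimal standalone sketch of the sizing mismatch this hunk fixes. The sizes are illustrative; only the names `state_E` and `max_rows` come from the diff, everything else is an assumption:

```python
import torch

# Illustrative sizes; state_E and max_rows follow the diff's naming.
state_E = 8                      # num_local_experts on this rank
num_tokens, top_k = 4, 4
max_rows = num_tokens * top_k    # total routed (token, expert) pairs = 16

# Old sizing: room for only state_E entries, but the micro kernel writes
# one entry per routed pair, so the guard
# 'flat_ids.numel() <= workspace.compact_topk_ids.numel()' fired (16 > 8).
old_buf = torch.empty(state_E, dtype=torch.int32)
assert max_rows > old_buf.numel()

# New sizing covers both consumers: the compaction path needs state_E
# entries, the micro kernel's flattened write needs max_rows.
new_buf = torch.empty(max(state_E, max_rows), dtype=torch.int32)
assert max_rows <= new_buf.numel()
```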
@@ -68,8 +68,8 @@ def compact_topk_ids(
         return
     if compact_topk_ids.numel() < total_pairs:
         raise ValueError("compact_topk_ids must have at least total_pairs elements")
-    if weight_expert_ids.numel() < total_pairs:
-        raise ValueError("weight_expert_ids must have at least total_pairs elements")
+    # weight_expert_ids writes at indices 0..active_expert_count-1 (bounded by
+    # the number of local experts, not total_pairs), so no size check is needed here.
Comment on lines +71 to +72
⚠️ Potential issue | 🟡 Minor

Docstring contract is now stale after removing the size check.

Line 62 still says weight_expert_ids must be [>=total_pairs], but Lines 71-72 explicitly relax that. Please update the docstring to reflect the new expected sizing contract.

Suggested doc fix
-        weight_expert_ids: [>=total_pairs] int32 — output: local->global map.
+        weight_expert_ids: int32 — output: local->global map; size must cover
+            the maximum number of unique experts expected in `topk_ids`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines
71-72, the docstring describing the expected size for weight_expert_ids is
stale: it currently claims weight_expert_ids must be length >= total_pairs, but
the code now only writes indices 0..active_expert_count-1, bounded by the number
of local experts. Update the docstring near the weight_expert_ids parameter (the
docstring around line ~62 in triton_compact.py) to state that weight_expert_ids
needs to be sized to accommodate active_expert_count (or the number of local
experts) rather than total_pairs, and clarify that no full total_pairs-sized
buffer is required because writes are limited to 0..active_expert_count-1.

     if active_expert_count.numel() != 1:
         raise ValueError("active_expert_count must have shape [1]")
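To make the relaxed contract concrete, a minimal sketch of the behavior the new comment describes (the tensor contents and the use of `torch.unique` are assumptions for illustration, not the library's actual kernel logic):

```python
import torch

# Eight routed (token, expert) pairs, but only three distinct experts hit.
topk_ids = torch.tensor([3, 1, 3, 6, 1, 6, 6, 3], dtype=torch.int32)
total_pairs = topk_ids.numel()                # 8

# The compaction only writes weight_expert_ids at indices
# 0..active_expert_count-1, and active_expert_count is bounded by the
# number of local experts, never by total_pairs.
unique_experts = torch.unique(topk_ids)       # tensor([1, 3, 6])
active_expert_count = unique_experts.numel()  # 3

num_local_experts = 8
weight_expert_ids = torch.arange(num_local_experts, dtype=torch.int32)
weight_expert_ids[:active_expert_count] = unique_experts

assert active_expert_count <= num_local_experts
assert active_expert_count < total_pairs  # why a total_pairs check is too strict
```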
72 changes: 72 additions & 0 deletions tests/moe/test_b12x_fused_moe.py
@@ -1067,6 +1067,78 @@ def test_micro_single_token_unique_path(self):
             f"(atol={atol:.4f})"
         )

+    @pytest.mark.parametrize(
+        "num_tokens,top_k,num_experts",
+        [
+            (2, 8, 8),  # total_pairs=16 > num_local_experts=8
+            (4, 8, 16),  # total_pairs=32 > num_local_experts=16
+            (4, 4, 8),  # total_pairs=16 > num_local_experts=8
+        ],
Comment on lines +1073 to +1076
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Test case (4, 8, 16) may not reliably exercise the micro-kernel path.

routed_rows = num_tokens * top_k = 4 * 8 = 32. The micro-kernel is selected only when routed_rows <= micro_cutover, which the PR describes as "typically 20–40". If micro_cutover is 20 for the target hardware/configuration, this case silently falls through to the standard path and does not exercise the regression being fixed.

Consider replacing (4, 8, 16) with a case whose routed_rows is safely within the guaranteed micro cutover range — e.g. (2, 8, 16) gives routed_rows=16 < 20.
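For concreteness, a sketch of the dispatch predicate described above; the `micro_cutover` value and the function name are assumptions (taking the low end of the quoted "typically 20-40" range), not the library's actual API:

```python
MICRO_CUTOVER = 20  # assumed: the low end of the "typically 20-40" range

def selects_micro_kernel(num_tokens: int, top_k: int) -> bool:
    # Hypothetical predicate mirroring the comment: the micro kernel is
    # taken only for small routed batches.
    routed_rows = num_tokens * top_k
    return routed_rows <= MICRO_CUTOVER

assert selects_micro_kernel(2, 8)       # routed_rows=16: micro path
assert not selects_micro_kernel(4, 8)   # routed_rows=32: standard path
```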

🔧 Suggested replacement
-            (4, 8, 16),  # total_pairs=32 > num_local_experts=16
+            (2, 8, 16),  # total_pairs=16 > num_local_experts=16, routed_rows safely within micro cutover
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-            (2, 8, 8),  # total_pairs=16 > num_local_experts=8
-            (4, 8, 16),  # total_pairs=32 > num_local_experts=16
-            (4, 4, 8),  # total_pairs=16 > num_local_experts=8
-        ],
+            (2, 8, 8),  # total_pairs=16 > num_local_experts=8
+            (2, 8, 16),  # total_pairs=16 > num_local_experts=16, routed_rows safely within micro cutover
+            (4, 4, 8),  # total_pairs=16 > num_local_experts=8
+        ],
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/moe/test_b12x_fused_moe.py` around lines 1073-1076, the test case
tuple (4, 8, 16) in the parameter list can produce routed_rows=32, which may not
hit the micro-kernel path; replace that tuple with (2, 8, 16) so routed_rows
becomes 16 (< typical micro_cutover) to reliably exercise the micro-kernel path
in the test (update the tuple in the list of test cases in
test_b12x_fused_moe.py).

+    )
+    def test_micro_pairs_exceed_local_experts(
+        self, num_tokens: int, top_k: int, num_experts: int
+    ):
+        """Regression test: micro kernel when num_tokens * top_k > num_local_experts.
+
+        The workspace compact_topk_ids buffer was previously sized state_E
+        (num_local_experts), but the micro kernel fills it with total_pairs =
+        num_tokens * top_k. When total_pairs > num_local_experts the assertion
+        'flat_ids.numel() <= workspace.compact_topk_ids.numel()' fired.
+
+        Fixed by sizing compact_topk_ids as max(state_E, max_rows).
+        """
+        from flashinfer import b12x_fused_moe
+
+        hidden_size, intermediate_size = 256, 512
+
+        tensors = create_moe_tensors(
+            num_tokens=num_tokens,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_experts=num_experts,
+            num_local_experts=num_experts,
+            top_k=top_k,
+        )
+
+        result = b12x_fused_moe(
+            x=tensors["x_bf16"],
+            w1_weight=tensors["w1_weight"],
+            w1_weight_sf=tensors["w1_weight_sf"],
+            w1_alpha=tensors["w1_alpha"],
+            fc2_input_scale=tensors["fc2_input_scale"],
+            w2_weight=tensors["w2_weight"],
+            w2_weight_sf=tensors["w2_weight_sf"],
+            w2_alpha=tensors["w2_alpha"],
+            token_selected_experts=tensors["token_selected_experts"],
+            token_final_scales=tensors["token_final_scales"],
+            num_experts=num_experts,
+            top_k=top_k,
+        )
+
+        assert result.shape == (num_tokens, hidden_size)
+        assert not torch.isnan(result).any()
+        assert not torch.isinf(result).any()
+
+        ref_output = compute_reference_moe_fp4(
+            hidden_states=tensors["x_bf16"].float().cuda(),
+            gemm1_weights=tensors["w1_weight_bf16"].float().cuda(),
+            gemm2_weights=tensors["w2_weight_bf16"].float().cuda(),
+            token_selected_experts=tensors["token_selected_experts"],
+            token_final_scales=tensors["token_final_scales"],
+            num_tokens=num_tokens,
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            fc2_input_scale=tensors["fc2_input_scale"],
+        )
+
+        passed, percent_within, atol = check_accuracy(result, ref_output)
+        assert passed, (
+            f"Micro pairs>experts: {percent_within * 100:.2f}% within tolerance "
+            f"(atol={atol:.4f}, tokens={num_tokens}, top_k={top_k}, experts={num_experts})"
+        )
Comment thread: coderabbitai[bot] marked this conversation as resolved.


 # =============================================================================
 # Test Class: ReLU2 Activation (SM120-only, non-gated)