fix(sm12x): fix micro-kernel workspace sizing when routed_rows > num_local_experts (#3191)
Conversation
…local_experts

Two bugs in the SM12x b12x MoE micro-kernel path (triggered when routed_rows <= micro_cutover, typically 20–40):

1. allocate_sm120_static_workspace: compact_topk_ids was sized state_E, but the micro-kernel path passes flat_ids of length routed_rows (= num_tokens * num_topk), which can exceed num_local_experts (state_E) for small batch sizes. Fix: size as max(state_E, max_rows).
2. compact_topk_ids (triton_compact.py): the validation check required weight_expert_ids.numel() >= total_pairs. This was wrong: the kernel writes to weight_expert_ids at indices 0..active_expert_count-1, which is bounded by state_E (num local experts), not total_pairs. The check incorrectly rejected valid calls where total_pairs > state_E. Fix: remove the check.

Together these caused an assertion failure whenever num_tokens * num_topk > num_local_experts at micro-kernel batch sizes (e.g. 2 tokens × 8 topk = 16 pairs but only 8 local experts).

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
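Fix (1) can be sketched in isolation. This is a hedged illustration, not the FlashInfer code: the helper name `compact_topk_ids_capacity` is hypothetical, while `state_E` and `max_rows` mirror the names in the description above; the real allocation lives in `allocate_sm120_static_workspace` in moe_dispatch.py.

```python
def compact_topk_ids_capacity(state_E: int, max_rows: int) -> int:
    """Capacity needed for the compact_topk_ids buffer (sketch only).

    The micro-kernel pre-pass writes flat_ids of length
    routed_rows = num_tokens * num_topk, which can exceed the number
    of local experts (state_E) at small batch sizes, so the buffer
    must cover whichever bound is larger.
    """
    return max(state_E, max_rows)


# 2 tokens x 8 topk = 16 routed pairs, but only 8 local experts:
assert compact_topk_ids_capacity(8, 2 * 8) == 16  # old sizing (state_E = 8) was too small
assert compact_topk_ids_capacity(64, 16) == 64    # large-E case is unchanged
```

The old allocation effectively returned `state_E` alone, which is exactly the `8 < 16` assertion failure described above.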
cc: @bkryu -- can you please review this PR?
🚥 Pre-merge checks: ✅ 5 passed.
Code Review
This pull request adjusts the allocation size for compact_topk_ids in the MoE dispatch logic and removes an incorrect validation check in the Triton compacting function. Feedback was provided to improve the clarity of a code comment by avoiding the use of a variable name that is not defined within the local scope.
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace state_E (not in scope here) with 'the number of local experts'. Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
🧹 Nitpick comments (1)

flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)

Line 142: Refresh the stale field-level comment. The dataclass annotation still says `# [state_E] int32, for micro kernel pre-pass`, but the buffer is now sized to `max(state_E, max_rows)`. Worth updating to avoid confusing future readers about the actual capacity contract.

📝 Suggested update

- compact_topk_ids: torch.Tensor  # [state_E] int32, for micro kernel pre-pass
+ compact_topk_ids: torch.Tensor  # [max(state_E, max_rows)] int32, for micro kernel pre-pass
# weight_expert_ids writes at indices 0..active_expert_count-1 (bounded by
# the number of local experts, not total_pairs), so no size check is needed here.
Docstring contract is now stale after removing the size check.

Line 62 still says weight_expert_ids must be [>=total_pairs], but lines 71-72 explicitly relax that. Please update the docstring to reflect the new expected sizing contract.

Suggested doc fix

- weight_expert_ids: [>=total_pairs] int32 — output: local->global map.
+ weight_expert_ids: int32 — output: local->global map; size must cover
+     the maximum number of unique experts expected in `topk_ids`.
+ the maximum number of unique experts expected in `topk_ids`.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines
71 - 72, The docstring describing the expected size for weight_expert_ids is
stale: it currently claims weight_expert_ids must be length >= total_pairs but
the code now only writes indices 0..active_expert_count-1 bounded by the number
of local experts. Update the docstring near the weight_expert_ids parameter (the
docstring around line ~62 in triton_compact.py) to state that weight_expert_ids
needs to be sized to accommodate active_expert_count (or the number of local
experts) rather than total_pairs, and clarify that no full total_pairs-sized
buffer is required because writes are limited to 0..active_expert_count-1.
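The relaxed contract can be illustrated with a toy plain-Python model of the compaction (a hypothetical helper, not the actual Triton kernel): `weight_expert_ids` receives one entry per *unique* active expert, so its required length is bounded by the number of local experts rather than by `total_pairs`.

```python
def compact(topk_ids):
    """Toy model of the pre-pass compaction: map each routed pair to a
    dense local expert index and return the local->global expert map
    (the role weight_expert_ids plays in the real kernel)."""
    local_of = {}      # global expert id -> dense local index
    compact_ids = []
    for e in topk_ids:
        if e not in local_of:
            local_of[e] = len(local_of)
        compact_ids.append(local_of[e])
    weight_expert_ids = list(local_of)  # dicts preserve insertion order
    return compact_ids, weight_expert_ids


# 2 tokens x 8 topk = 16 pairs routed over only 8 local experts:
pairs = [0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]
compact_ids, weights = compact(pairs)
assert len(compact_ids) == 16  # one entry per routed pair (total_pairs)
assert len(weights) == 8       # bounded by active experts, not total_pairs
```

This is why a `numel() >= total_pairs` check on `weight_expert_ids` rejects valid small-batch calls: the map never grows past the number of distinct experts seen.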
…um_local_experts

Adds test_micro_pairs_exceed_local_experts with three configurations where num_tokens * top_k > num_local_experts, directly exercising the bug fixed by sizing compact_topk_ids as max(state_E, max_rows).

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
/bot run

pls address pre-commit
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com> Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
♻️ Duplicate comments (1)

tests/moe/test_b12x_fused_moe.py (1)

Lines 1070-1077: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win. Outstanding pre-commit formatting failure: ruff-format output still needs to be committed.

A previous CI run shows ruff-format auto-modified this exact hunk, and the reformatted version was not committed. This will continue to fail the pre-commit gate on every subsequent CI run until the formatter output is checked in.

Run `pre-commit run --all-files` (or `ruff format tests/moe/test_b12x_fused_moe.py`) locally and commit the result.
(2, 8, 8),   # total_pairs=16 > num_local_experts=8
(4, 8, 16),  # total_pairs=32 > num_local_experts=16
(4, 4, 8),   # total_pairs=16 > num_local_experts=8
],
Test case (4, 8, 16) may not reliably exercise the micro-kernel path.
routed_rows = num_tokens * top_k = 4 * 8 = 32. The micro-kernel is selected only when routed_rows <= micro_cutover, which the PR describes as "typically 20–40". If micro_cutover is 20 for the target hardware/configuration, this case silently falls through to the standard path and does not exercise the regression being fixed.
Consider replacing (4, 8, 16) with a case whose routed_rows is safely within the guaranteed micro cutover range — e.g. (2, 8, 16) gives routed_rows=16 < 20.
🔧 Suggested replacement

- (4, 8, 16),  # total_pairs=32 > num_local_experts=16
+ (2, 8, 16),  # total_pairs=16 > num_local_experts=16, routed_rows safely within micro cutover

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

(2, 8, 8),   # total_pairs=16 > num_local_experts=8
(2, 8, 16),  # total_pairs=16 > num_local_experts=16, routed_rows safely within micro cutover
(4, 4, 8),   # total_pairs=16 > num_local_experts=8
],
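The cutover concern can be checked arithmetically. A hedged sketch, assuming (as the review does) that the effective micro_cutover could sit at the low end of the stated "typically 20–40" range:

```python
# (num_tokens, top_k, num_local_experts) cases from the original parametrize list
cases = [(2, 8, 8), (4, 8, 16), (4, 4, 8)]

MICRO_CUTOVER = 20  # assumed conservative lower bound; actual value is config-dependent

for num_tokens, top_k, num_experts in cases:
    routed_rows = num_tokens * top_k
    hits_micro = routed_rows <= MICRO_CUTOVER
    print(f"{(num_tokens, top_k, num_experts)}: routed_rows={routed_rows}, "
          f"micro path={hits_micro}")
# Only (4, 8, 16) yields routed_rows=32, which exceeds a cutover of 20
# and could silently fall through to the standard (non-micro) path.
```

Swapping that case for (2, 8, 16) keeps routed_rows at 16 while still satisfying total_pairs >= num_local_experts, which is the condition the regression test needs.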
Summary

Two bugs in the b12x_fused_moe micro-kernel path (SM120/SM121, triggered when routed_rows <= micro_cutover, typically 20–40):

1. allocate_sm120_static_workspace (moe_dispatch.py): compact_topk_ids was sized state_E (num local experts), but the micro-kernel fills it with flat_ids of length routed_rows = num_tokens * num_topk. When num_tokens * num_topk > num_local_experts (e.g. 2 tokens × 8 topk = 16 pairs, 8 local experts), this caused an assertion failure: compact_topk_ids buffer too small: 8 < 16. Fix: size as max(state_E, max_rows).
2. compact_topk_ids (triton_compact.py): validation required weight_expert_ids.numel() >= total_pairs. This was wrong: the Triton kernel writes to weight_expert_ids only at indices 0..active_expert_count-1, bounded by state_E (the number of local experts), not total_pairs. The check rejected valid calls where total_pairs > state_E. Fix: remove the check (with an explanatory comment).

Both bugs surface together whenever the batch is small enough to hit the micro-kernel path but num_tokens * num_topk > num_local_experts.

Test plan

The small-batch cases in tests/kernels/moe/test_flashinfer_b12x_moe.py cover the micro-kernel path; 24/24 pass with this fix on DGX Spark (SM121).