feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2#3080
feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2#3080bkryu wants to merge 14 commits intoflashinfer-ai:mainfrom
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdded gated (SiLU) vs non‑gated (ReLU²) activation support across benchmarks and fused MoE CuTe‑DSL kernels; introduced SM12x b12x functional API and wrapper, a tiny‑decode micro backend with routing compaction, activation‑aware kernel compilation/storage, weight/layout/quantization changes, and new tests. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller as Caller
participant Dispatch as Dispatch Layer
participant Compact as Triton Compact
participant Cache as Kernel Cache
participant Micro as Micro Kernel
participant Static as Static Kernel
participant Dynamic as Dynamic Kernel
participant Activation as Activation Func
Caller->>Dispatch: submit token_selected_experts, weights, scales, activation
Dispatch->>Dispatch: compute routed_rows = num_tokens * top_k
Dispatch->>Dispatch: select backend (micro/static/dynamic) using cutovers
alt Micro Path
Dispatch->>Compact: compact_topk_ids(topk_ids)
Compact-->>Dispatch: compact_ids, active_expert_count, weight_expert_ids
Dispatch->>Cache: lookup/compile micro kernel (activation, mac_override)
Cache-->>Dispatch: micro kernel
Dispatch->>Micro: launch(compact_ids, activation, weights, scales)
Micro->>Activation: apply SiLU or ReLU²
Micro->>Caller: write outputs
else Static Path
Dispatch->>Cache: lookup/compile static kernel (activation)
Cache-->>Dispatch: static kernel
Dispatch->>Static: launch(topk_ids, activation, weights, scales)
Static->>Activation: apply SiLU or ReLU²
Static->>Caller: write outputs
else Dynamic Path
Dispatch->>Dynamic: launch dynamic kernel (activation)
Dynamic->>Activation: apply activation variant
Dynamic->>Caller: write outputs
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request adds support for non-gated MoE activations (ReLU2) and introduces a specialized micro-kernel for small decode batches on Blackwell SM120/SM121 architectures. The changes include updates to the static and dynamic CuTe DSL kernels, a new Triton-based ID compaction pre-pass, and updated benchmarking and testing utilities. Feedback indicates that the "moe_micro_kernel.py" file is missing from the PR and identifies a potential out-of-bounds risk when slicing the workspace buffer for compact IDs.
| from .triton_compact import compact_topk_ids as _triton_compact_topk_ids | ||
|
|
||
| # Run Triton pre-pass to compact global expert IDs to dense local indices | ||
| compact_ids = workspace.compact_topk_ids[: flat_ids.numel()] |
There was a problem hiding this comment.
There is a potential out-of-bounds risk if flat_ids.numel() exceeds state_E. While the micro kernel path is currently restricted to routed_rows <= 40, the workspace allocation for compact_topk_ids uses state_E. It would be safer to ensure that the slice does not exceed the allocated size of the workspace buffer, or add an explicit check.
There was a problem hiding this comment.
Good catch on the defensive coding. In practice this can't overflow — compact_topk_ids is sized [state_E] (typically 256-512) while flat_ids.numel() is at most 40 on the micro path (the cutover threshold). But the invariant should be explicit. Added an assertion in the next commit
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
benchmarks/routines/moe.py (1)
1790-1813:⚠️ Potential issue | 🟡 MinorPass
is_gatedinto the FP8 bandwidth model too.TFLOPS now distinguishes gated vs non-gated activations, but both FP8 bandwidth calls still rely on the default path. ReLU2 runs will therefore report inconsistent bandwidth.
🛠️ Suggested fix
tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time, input_dtype, weight_dtype, input_format="fp8", weight_format="fp8", routing_logits_dtype=routing_logits.dtype, active_experts=int(selected_experts.unique().numel()), verbose=args.verbose, + is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu), )tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time, input_dtype, weight_dtype, input_format="fp8", weight_format="fp8", routing_logits_dtype=routing_logits.dtype, active_experts=int(selected_experts.unique().numel()), verbose=args.verbose, + is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu), )Also applies to: 2025-2048
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/moe.py` around lines 1790 - 1813, The FP8 bandwidth call is missing the is_gated flag, causing gated vs non-gated activations to report inconsistent bandwidth; update the calculate_moe_kernel_bandwidth invocation(s) to pass the same is_gated boolean used for calculate_moe_tflops (e.g., is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu)) so both calculate_moe_tflops(...) and calculate_moe_kernel_bandwidth(...) receive the same gating hint (apply the same change to the other occurrences around the later block that mirrors this code).flashinfer/fused_moe/cute_dsl/fused_moe.py (1)
362-394:⚠️ Potential issue | 🟠 MajorReject
relu2outside the SM120/SM121 path.These new public parameters are only honored in the SM120 branch. The fallback path still goes through
_moe_core_impl(), which hard-wires the SwiGLU fusion helper, soactivation_type="relu2"on SM100/SM103 can run the wrong math or hit mismatched FC1 shapes instead of failing fast.🛠️ Suggested guard
@@ - self.activation_type = activation_type + if activation_type not in {"silu", "relu2"}: + raise ValueError(f"Unsupported activation_type: {activation_type!r}") + self.activation_type = activation_type @@ major, minor = torch.cuda.get_device_capability(device) self._is_sm120 = major == 12 + if activation_type != "silu" and not self._is_sm120: + raise ValueError( + "activation_type='relu2' is only supported on SM120/SM121" + )def cute_dsl_fused_moe_nvfp4( @@ - if num_local_experts is None: + if activation_type not in {"silu", "relu2"}: + raise ValueError(f"Unsupported activation_type: {activation_type!r}") + + if num_local_experts is None: num_local_experts = num_experts @@ major, _ = torch.cuda.get_device_capability(x.device) if major == 12: ... + elif activation_type != "silu": + raise ValueError( + "activation_type='relu2' is only supported on SM120/SM121" + )Also applies to: 827-916
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 362 - 394, The constructor accepts activation_type but the non-SM120/SM121 path still calls _moe_core_impl which assumes SwiGLU; add a runtime guard in the initializer (or immediately before dispatch to _moe_core_impl) that checks activation_type and the detected GPU SM version and either raise a clear error or restrict allowed values when SM < 120 (e.g., if activation_type == "relu2" and not on SM120/121, raise ValueError). Update the dispatch code path that calls _moe_core_impl (and any fallback branches referenced around the alternate implementation) to enforce this same check so relu2 is only honored on the SM120/SM121 branch and cannot silently run with the SwiGLU helper.flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)
138-185:⚠️ Potential issue | 🔴 CriticalSize
compact_topk_idsfor routed rows, not experts.Line 824 slices this buffer to
flat_ids.numel()(num_tokens * top_k), but the workspace only allocatesstate_Eentries. Any micro launch with more routed rows than local experts will write past the end of the buffer.🛠️ Suggested fix
- compact_topk_ids: torch.Tensor # [state_E] int32, for micro kernel pre-pass + compact_topk_ids: torch.Tensor # [max_rows] int32, for micro kernel pre-pass @@ - compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device), + compact_topk_ids=torch.empty(max_rows, dtype=torch.int32, device=device),🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines 138 - 185, compact_topk_ids is currently allocated with length state_E but is later sliced to flat_ids.numel() (num_tokens * top_k) in the micro-kernel pre-pass, causing out-of-bounds writes when num_tokens > state_E; in allocate_sm120_static_workspace change the compact_topk_ids allocation in Sm120StaticMoEWorkspace to have capacity for the worst-case routed rows times top-k (e.g. torch.empty(state_E * max_rows * num_topk, dtype=torch.int32, device=device) or at minimum torch.empty(max_rows * state_E * num_topk, ...)) so flat_ids.numel() can always fit, and keep references to compact_topk_ids, allocate_sm120_static_workspace, Sm120StaticMoEWorkspace, num_topk, max_rows, and state_E to locate the change.
🧹 Nitpick comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py (1)
76-86: Add a defensive size guard for micro-only usage.This kernel is O(BLOCK²) in a single program; adding an explicit upper bound makes accidental large launches fail fast with a clear message.
Possible guardrail
block = triton.next_power_of_2(total_pairs) + if block > 256: + raise ValueError( + f"compact_topk_ids is intended for micro batches; got total_pairs={total_pairs}" + ) num_warps = 1 if block <= 16 else 2🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines 76 - 86, Add a defensive size guard before launching _compact_topk_ids_kernel to prevent accidental large O(BLOCK²) launches: compute block = triton.next_power_of_2(total_pairs) as you do, then check against a small hard limit (e.g. MAX_BLOCK = 64 or 128) and/or a MAX_PAIRS derived limit and raise a clear RuntimeError if block > MAX_BLOCK (include block and total_pairs in the message). Keep the existing num_warps logic and kernel args (_compact_topk_ids_kernel, topk_ids, compact_topk_ids, weight_expert_ids, active_expert_count, total_pairs, BLOCK=block, num_warps=num_warps) unchanged; just insert the guard using the same symbols so oversized launches fail fast with a descriptive error.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1367-1371: The code silently maps unsupported ActivationType
values to "silu"; change the logic to validate args.activation_type against the
supported mapping instead of defaulting. Use the _ACT_STR dict to look up
activation_str and if the activation_type is not present raise a clear exception
(e.g., ValueError) mentioning the unsupported ActivationType and listing
supported keys; also compute is_gated from ActivationType.Geglu and
ActivationType.Swiglu as before but ensure Geglu is rejected if not in _ACT_STR
so it cannot silently run the SiLU kernel.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py`:
- Around line 539-541: Comments describing w13 tile ordering conflict with
actual code: update the inline comments around the w13 TMA descriptor creation
(the lines that mention "Gate tiles at N=..." and "up tiles at N=...") to
reflect the actual ordering used by the code path where gate_slice_idx =
intermediate_slice + gate_tile_cnt (i.e., up tiles occupy the first half of
N-tiles and gate tiles the second half). Locate the call to
self._dense_cls._make_tma_atoms_and_tensors that produces tma_b_w13/gB_w13 and
any other similar comments (also around the other occurrence near lines
~1168-1171) and change the wording so it states "Up tiles at N=0..I_tp/tile_N-1,
Gate tiles at N=I_tp/tile_N..2*I_tp/tile_N-1" or equivalent that matches the
gate_slice_idx logic.
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1093-1096: Replace the custom skip gating that defines
sm120_cuda13 (which currently uses is_sm120_family() and _has_cuda_13()) with
the repository-standard capability checks from flashinfer.utils or the API
capability method; specifically, remove is_sm120_family()/_has_cuda_13() and use
the appropriate flashinfer.utils helper (e.g., is_sm120_supported() or analogous
is_sm90a_supported()/is_sm100a_supported()) or call
api_name.is_compute_capability_supported(cc) to decide the skip. Ensure the new
marker still uses pytest.mark.skipif(...) with a descriptive reason string
indicating the required SM/CUDA capability.
---
Outside diff comments:
In `@benchmarks/routines/moe.py`:
- Around line 1790-1813: The FP8 bandwidth call is missing the is_gated flag,
causing gated vs non-gated activations to report inconsistent bandwidth; update
the calculate_moe_kernel_bandwidth invocation(s) to pass the same is_gated
boolean used for calculate_moe_tflops (e.g., is_gated=args.activation_type in
(ActivationType.Swiglu, ActivationType.Geglu)) so both calculate_moe_tflops(...)
and calculate_moe_kernel_bandwidth(...) receive the same gating hint (apply the
same change to the other occurrences around the later block that mirrors this
code).
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 138-185: compact_topk_ids is currently allocated with length
state_E but is later sliced to flat_ids.numel() (num_tokens * top_k) in the
micro-kernel pre-pass, causing out-of-bounds writes when num_tokens > state_E;
in allocate_sm120_static_workspace change the compact_topk_ids allocation in
Sm120StaticMoEWorkspace to have capacity for the worst-case routed rows times
top-k (e.g. torch.empty(state_E * max_rows * num_topk, dtype=torch.int32,
device=device) or at minimum torch.empty(max_rows * state_E * num_topk, ...)) so
flat_ids.numel() can always fit, and keep references to compact_topk_ids,
allocate_sm120_static_workspace, Sm120StaticMoEWorkspace, num_topk, max_rows,
and state_E to locate the change.
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 362-394: The constructor accepts activation_type but the
non-SM120/SM121 path still calls _moe_core_impl which assumes SwiGLU; add a
runtime guard in the initializer (or immediately before dispatch to
_moe_core_impl) that checks activation_type and the detected GPU SM version and
either raise a clear error or restrict allowed values when SM < 120 (e.g., if
activation_type == "relu2" and not on SM120/121, raise ValueError). Update the
dispatch code path that calls _moe_core_impl (and any fallback branches
referenced around the alternate implementation) to enforce this same check so
relu2 is only honored on the SM120/SM121 branch and cannot silently run with the
SwiGLU helper.
---
Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py`:
- Around line 76-86: Add a defensive size guard before launching
_compact_topk_ids_kernel to prevent accidental large O(BLOCK²) launches: compute
block = triton.next_power_of_2(total_pairs) as you do, then check against a
small hard limit (e.g. MAX_BLOCK = 64 or 128) and/or a MAX_PAIRS derived limit
and raise a clear RuntimeError if block > MAX_BLOCK (include block and
total_pairs in the message). Keep the existing num_warps logic and kernel args
(_compact_topk_ids_kernel, topk_ids, compact_topk_ids, weight_expert_ids,
active_expert_count, total_pairs, BLOCK=block, num_warps=num_warps) unchanged;
just insert the guard using the same symbols so oversized launches fail fast
with a descriptive error.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5e3f31a2-4325-47e8-9db8-0acd0b10bef6
📒 Files selected for processing (11)
benchmarks/routines/moe.pybenchmarks/routines/moe_utils.pyflashinfer/cute_dsl/fp4_common.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.pyflashinfer/fused_moe/cute_dsl/fused_moe.pytests/moe/test_cute_dsl_fused_moe.py
|
/bot run |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 2030-2037: The bandwidth computation is still using the default
gated accounting while TFLOPS now selects non-gated accounting for some
activations; update the call to calculate_moe_kernel_bandwidth to pass the same
is_gated boolean used for calculate_moe_tflops (i.e.,
is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu)
or the expression that produced non-gated for Relu2) so both TFLOPS and kernel
bandwidth use the same activation-aware gating flag (refer to
calculate_moe_tflops and calculate_moe_kernel_bandwidth to locate the calls).
- Around line 1795-1802: The TFLOPS call uses args.activation_type to set
is_gated but this routine still constructs gated FC1 tensors and
run_fp8_block_moe never receives an activation flag, so reported TFLOPS can
diverge from the executed kernel; fix by not switching the is_gated flag based
on args.activation_type here — either hard-code is_gated=True (gated-only path)
or derive is_gated from the same gated-only indicator used when building tensors
(e.g., the 2 * intermediate_size gated FC1 logic) and/or update
run_fp8_block_moe to accept and forward an activation_type so activation-based
toggles are consistent with calculate_moe_tflops; reference
calculate_moe_tflops, run_fp8_block_moe, args.activation_type and the gated FC1
construction (2 * intermediate_size) when making the change.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 43a02c6d-a59b-416a-b805-d1e691a5397e
📒 Files selected for processing (6)
benchmarks/routines/moe.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.pytests/moe/test_cute_dsl_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/moe/test_cute_dsl_fused_moe.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
nv-yunzheq
left a comment
There was a problem hiding this comment.
Generally look good to me. Left a few comments to make sure sm10x cute dsl moe wasn't get changed with the change
| num_local_experts=num_experts_local, | ||
| scatter_output=moe_output, | ||
| ) | ||
| # NOTE: SM120/SM121 dispatch is handled by callers (CuteDslMoEWrapper.run |
There was a problem hiding this comment.
Thanks Yunzhe for catching this. Doing another scan to remove references to sm120/121/b12x in the cute dsl MoE path.
There was a problem hiding this comment.
Here the comments are not correct. Maybe the cleaner for this file is to just to restore the status of before b12x being added
There was a problem hiding this comment.
Ditto as above. Reverting to
x: NVFP4-quantized input [num_tokens, hidden_size // 2].
x_sf: Scale factors for x.
| sm100_required = pytest.mark.skipif( | ||
| not is_sm100_family() or (is_sm120_family() and not _has_cuda_13()), | ||
| reason="Requires SM100/SM103 or SM120/SM121 GPU (SM120 requires CUDA 13+)", | ||
| not is_sm100_family() or is_sm120_family(), |
There was a problem hiding this comment.
Let's also maybe restore this file to the status where b12x being added. I think the logic here is not correct
There was a problem hiding this comment.
Good catch let me revert tests/moe/test_cute_dsl_fused_moe.py
| token_final_scales: torch.Tensor, | ||
| num_experts: int, | ||
| top_k: int, | ||
| num_local_experts: Optional[int] = None, |
There was a problem hiding this comment.
I suggest putting some non-essential args after , *, as keyword-only args
tho output, act etc are common enough. idk about w1 w2 alpha tho, are they like global scales?
There was a problem hiding this comment.
Good point for clarify. Reordering in the next commit.
Regarding your question, yes, the w1 an w2 alphas are per-expert global scales
There was a problem hiding this comment.
(Also, not just reordered but added the * as suggested)
| @@ -66,12 +66,12 @@ def _has_cuda_13(): | |||
| not is_cute_dsl_available(), reason="CuteDSL not available" | |||
| ) | |||
| sm100_required = pytest.mark.skipif( | |||
There was a problem hiding this comment.
will discuss it more in a call for me to catch up understanding
There was a problem hiding this comment.
This is a good point and I should have been more clear. Reverting the entire file to be at the state prior to any b12x work because I am moving b12x moe tests that run on sm12x to a separate test file
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (1)
tests/moe/test_cute_dsl_fused_moe.py (1)
38-56:⚠️ Potential issue | 🟠 MajorUse the repo capability helper here, not a local
major == 10check.This reintroduces custom architecture gating and drops the repo-standard supported-runtime check. On an SM10x environment that the API still considers unsupported, these tests will now run instead of skip and fail for the wrong reason. Please gate
sm100_requiredwithflashinfer.utils.is_sm100a_supported(...)or the API capability method.🛠️ Proposed fix
+from flashinfer.utils import is_sm100a_supported + def is_sm100_family(): - """Check for SM100 family (Blackwell: SM100, SM103). + """Check for a supported SM100/SM103 runtime. CuteDSL MoE NVFP4 kernels on SM10x use cute_dsl_fused_moe_nvfp4 API. SM120/121 tests are in test_b12x_fused_moe.py instead. """ if not torch.cuda.is_available(): return False - props = torch.cuda.get_device_properties(0) - return props.major == 10 + return is_sm100a_supported(torch.cuda.current_device())As per coding guidelines
tests/**/*.py: Skip test execution on unsupported GPU architectures usingflashinfer.utilscheck functions (is_sm90a_supported(),is_sm100a_supported(), etc.) or API methods likeapi_name.is_compute_capability_supported(cc).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 56, Replace the custom is_sm100_family() check and the sm100_required skip marker with the repo-standard capability check: remove or stop using is_sm100_family() and instead gate sm100_required using flashinfer.utils.is_sm100a_supported(...) (or the API capability method api_name.is_compute_capability_supported(cc)) so the pytest.mark.skipif uses the repo helper; update references to the skip decorator (sm100_required) to call flashinfer.utils.is_sm100a_supported() and ensure torch.cuda.is_available() logic is handled by that helper rather than checking props.major == 10 in the is_sm100_family function.
🧹 Nitpick comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)
543-764: Consider extracting shared fake tensor generation.The
_get_micro_kernelfunction duplicates most of the fake tensor generation code from_get_static_kernel(lines 611-730 mirror 385-503). Consider extracting a helper function to reduce duplication and improve maintainability.💡 Example helper extraction
def _make_moe_fake_tensors( state_E: int, weight_E: int, m: int, k: int, w1_rows: int, n: int, num_topk: int, max_rows: int, rows_pad_k: int, cols_pad_k: int, topk_ids_dtype: torch.dtype, ): """Build fake tensors for MoE kernel compilation.""" ab_dtype = cutlass.Float4E2M1FN sf_dtype = cutlass.Float8E4M3FN a_dtype = cutlass.BFloat16 alpha_dtype = cutlass.Float32 # ... shared fake tensor creation ... return (a_input_fake, topk_ids_fake, ..., token_weights_fake)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines 543 - 764, The fake-tensor construction in _get_micro_kernel duplicates the block from _get_static_kernel; extract the shared creation into a helper (e.g., _make_moe_fake_tensors) that accepts the identifying params (state_E, weight_E, m, k, w1_rows, n, num_topk, max_rows, rows_pad_k, cols_pad_k, topk_ids_dtype) and returns the tuple of fake tensors used by both functions (a_input_fake, topk_ids_fake, topk_weights_fake, packed_a_fake, sfa_fake, packed_a_storage_fake, scale_storage_fake, barrier_count_fake, barrier_epoch_fake, b_w13_fake, sfb_w13_fake, b_down_fake, sfb_down_fake, row_counts_fake, active_expert_count_fake, weight_expert_ids_fake, global_to_local_expert_fake, input_gs_fake, alpha_fake, down_alpha_fake, global_scale_fake, scatter_fake, token_map_fake, token_weights_fake); replace the duplicate blocks in _get_micro_kernel and _get_static_kernel to call this helper and use its returned values before compilation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 842-844: The micro_work_tiles computation is over-counting by
multiplying routed_rows by N_tiles; update micro_work_tiles to compute M_tiles =
max(1, (routed_rows + 128 - 1) // 128) and N_tiles = max(1, (n + 128 - 1) //
128) and then set micro_work_tiles = M_tiles * N_tiles (or max(1, M_tiles *
N_tiles) if desired). Keep the rest of the logic (calling
_lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) to get tuned_mac and
computing micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac))
unchanged so MAC clamping correctly reflects actual M_tiles × N_tiles
parallelism.
- Line 138: The compact_topk_ids buffer is undersized: change its
allocation/annotation (currently declared as compact_topk_ids: torch.Tensor #
[state_E]) to have capacity for flattened routed pairs (size max_rows *
num_topk, or at minimum max_rows) so flat_ids.numel() (num_tokens * top_k) never
exceeds workspace.compact_topk_ids.numel(); update the workspace struct/creation
sites (search for compact_topk_ids in moe_dispatch.py and the other occurrence
around line 185) to allocate a 1D int32 tensor of length max_rows * num_topk and
adjust any comments/annotations accordingly so the assertion at the
flat_ids.numel() <= workspace.compact_topk_ids.numel() check (around the 827-831
area) passes reliably.
---
Duplicate comments:
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 38-56: Replace the custom is_sm100_family() check and the
sm100_required skip marker with the repo-standard capability check: remove or
stop using is_sm100_family() and instead gate sm100_required using
flashinfer.utils.is_sm100a_supported(...) (or the API capability method
api_name.is_compute_capability_supported(cc)) so the pytest.mark.skipif uses the
repo helper; update references to the skip decorator (sm100_required) to call
flashinfer.utils.is_sm100a_supported() and ensure torch.cuda.is_available()
logic is handled by that helper rather than checking props.major == 10 in the
is_sm100_family function.
---
Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 543-764: The fake-tensor construction in _get_micro_kernel
duplicates the block from _get_static_kernel; extract the shared creation into a
helper (e.g., _make_moe_fake_tensors) that accepts the identifying params
(state_E, weight_E, m, k, w1_rows, n, num_topk, max_rows, rows_pad_k,
cols_pad_k, topk_ids_dtype) and returns the tuple of fake tensors used by both
functions (a_input_fake, topk_ids_fake, topk_weights_fake, packed_a_fake,
sfa_fake, packed_a_storage_fake, scale_storage_fake, barrier_count_fake,
barrier_epoch_fake, b_w13_fake, sfb_w13_fake, b_down_fake, sfb_down_fake,
row_counts_fake, active_expert_count_fake, weight_expert_ids_fake,
global_to_local_expert_fake, input_gs_fake, alpha_fake, down_alpha_fake,
global_scale_fake, scatter_fake, token_map_fake, token_weights_fake); replace
the duplicate blocks in _get_micro_kernel and _get_static_kernel to call this
helper and use its returned values before compilation.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4c52a6fd-d2eb-4171-a08e-39319f27fbbc
📒 Files selected for processing (3)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.pytests/moe/test_cute_dsl_fused_moe.py
| active_expert_count: torch.Tensor # [1] int32 | ||
| weight_expert_ids: torch.Tensor # [state_E] int32 | ||
| global_to_local_expert: torch.Tensor # [weight_E] int32 | ||
| compact_topk_ids: torch.Tensor # [state_E] int32, for micro kernel pre-pass |
There was a problem hiding this comment.
Critical: compact_topk_ids buffer is undersized — will cause assertion failures.
The compact_topk_ids buffer is allocated with size state_E (number of local experts), but it's used to store compacted routing IDs which have size num_tokens * top_k (routed pairs).
At line 827-831, the code asserts flat_ids.numel() <= workspace.compact_topk_ids.numel(), where flat_ids.numel() = num_tokens * top_k. With typical MoE configs (e.g., 8 local experts but up to 40 routed pairs for micro kernel), this assertion will fail.
The buffer should be sized to max_rows * num_topk or at minimum max_rows to accommodate the flattened routing IDs.
🐛 Proposed fix
`@dataclass`(kw_only=True)
class Sm120StaticMoEWorkspace:
"""Scratch buffers for one SM120 static MoE launch."""
...
- compact_topk_ids: torch.Tensor # [state_E] int32, for micro kernel pre-pass
+ compact_topk_ids: torch.Tensor # [max_rows * num_topk] int32, for micro kernel pre-pass workspace = Sm120StaticMoEWorkspace(
...
- compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device),
+ compact_topk_ids=torch.empty(max_rows * num_topk, dtype=torch.int32, device=device),
)Also applies to: 185-185
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` at line 138,
The compact_topk_ids buffer is undersized: change its allocation/annotation
(currently declared as compact_topk_ids: torch.Tensor # [state_E]) to have
capacity for flattened routed pairs (size max_rows * num_topk, or at minimum
max_rows) so flat_ids.numel() (num_tokens * top_k) never exceeds
workspace.compact_topk_ids.numel(); update the workspace struct/creation sites
(search for compact_topk_ids in moe_dispatch.py and the other occurrence around
line 185) to allocate a 1D int32 tensor of length max_rows * num_topk and adjust
any comments/annotations accordingly so the assertion at the flat_ids.numel() <=
workspace.compact_topk_ids.numel() check (around the 827-831 area) passes
reliably.
| micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128)) | ||
| tuned_mac = _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) | ||
| micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac) |
There was a problem hiding this comment.
Incorrect work tile count calculation may cause suboptimal MAC selection.
The formula routed_rows * max(1, (n + 128 - 1) // 128) multiplies raw routed rows by N-dimension tile count, but should instead compute the actual number of work tiles as M_tiles × N_tiles.
This over-estimates the work tile count (e.g., for routed_rows=20, n=4096: current = 20 * 32 = 640 vs correct = ceil(20/128) * 32 = 32), which may prevent MAC from being properly clamped to available parallelism.
🔧 Proposed fix
- micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
+ m_tiles = max(1, (routed_rows + 127) // 128)
+ n_tiles = max(1, (n + 127) // 128)
+ micro_work_tiles = m_tiles * n_tiles📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128)) | |
| tuned_mac = _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) | |
| micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac) | |
| m_tiles = max(1, (routed_rows + 127) // 128) | |
| n_tiles = max(1, (n + 127) // 128) | |
| micro_work_tiles = m_tiles * n_tiles | |
| tuned_mac = _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) | |
| micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac) |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
842 - 844, The micro_work_tiles computation is over-counting by multiplying
routed_rows by N_tiles; update micro_work_tiles to compute M_tiles = max(1,
(routed_rows + 128 - 1) // 128) and N_tiles = max(1, (n + 128 - 1) // 128) and
then set micro_work_tiles = M_tiles * N_tiles (or max(1, M_tiles * N_tiles) if
desired). Keep the rest of the logic (calling
_lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) to get tuned_mac and
computing micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac))
unchanged so MAC clamping correctly reflects actual M_tiles × N_tiles
parallelism.
|
/bot run |
|
/bot stop |
|
/bot run |
|
The GitLab CI pipeline #48708465 has been cancelled. |
There was a problem hiding this comment.
Actionable comments posted: 2
♻️ Duplicate comments (3)
tests/moe/test_cute_dsl_fused_moe.py (1)
38-56:⚠️ Potential issue | 🟡 MinorUse
flashinfer.utils.is_sm100a_supported()for the skip gate.This reintroduces a custom SM10x check and drops the repo-standard CUDA-version guard. The helper in
flashinfer.utilsalready matches the intended SM100/SM103 coverage, so using it keeps the test gate consistent with the rest of the suite.🔧 Proposed fix
import pytest import torch from torch.nn import functional as F from flashinfer.cute_dsl import is_cute_dsl_available +from flashinfer.utils import is_sm100a_supported - - -def is_sm100_family(): - """Check for SM100 family (Blackwell: SM100, SM103). - - CuteDSL MoE NVFP4 kernels are optimized for SM10x architecture. - """ - if not torch.cuda.is_available(): - return False - props = torch.cuda.get_device_properties(0) - return props.major == 10 # Skip decorators cute_dsl_available = pytest.mark.skipif( not is_cute_dsl_available(), reason="CuteDSL not available" ) sm100_required = pytest.mark.skipif( - not is_sm100_family(), - reason="Requires SM100 family GPU (Blackwell: SM100, SM103, SM110)", + not is_sm100a_supported(torch.device("cuda")), + reason="Requires supported SM10x GPU", )As per coding guidelines:
tests/**/*.py: Skip test execution on unsupported GPU architectures usingflashinfer.utilscheck functions (is_sm90a_supported(),is_sm100a_supported(), etc.) or API methods likeapi_name.is_compute_capability_supported(cc).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 56, Replace the custom is_sm100_family() check with the repo-standard helper: import and use flashinfer.utils.is_sm100a_supported() for the skip gate; specifically remove or stop using the local is_sm100_family() and change the sm100_required pytest.mark.skipif(...) to call flashinfer.utils.is_sm100a_supported(), ensuring the test uses the common utility (and add the import if missing).flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (2)
138-138:⚠️ Potential issue | 🔴 CriticalSize
compact_topk_idsfor routed pairs, not experts.This buffer is still allocated as
[state_E], but the micro pre-pass consumes it as a flattened routing-id buffer of lengthnum_tokens * top_k. That will trip the assertion at Line 827 for small-expert configs where routed pairs exceed local experts.🐛 Proposed fix
- compact_topk_ids: torch.Tensor # [state_E] int32, for micro kernel pre-pass + compact_topk_ids: torch.Tensor # [max_rows * num_topk] int32, for micro kernel pre-pass- compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device), + compact_topk_ids=torch.empty( + max_rows * num_topk, dtype=torch.int32, device=device + ),Also applies to: 185-185
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` at line 138, The buffer compact_topk_ids is currently allocated with size [state_E] (per-expert) but must be sized for routed pairs (flattened routing ids) consumed by the micro pre-pass; change the allocation(s) of compact_topk_ids to length num_tokens * top_k (or the explicit routed_pairs_count/num_routed_pairs variable used in routing) and keep dtype int32, and update both places where compact_topk_ids is created/allocated so the micro pre-pass (which reads it as a flattened routing-id buffer) no longer overruns the array.
842-844:⚠️ Potential issue | 🟡 MinorClamp micro MAC against actual
M_tiles × N_tiles.
routed_rows * ceil(n / 128)still overcounts work when the routed rows fit inside one M tile, so the MAC cap can stay higher than the kernel has parallel work for.🔧 Proposed fix
- micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128)) + m_tiles = max(1, (routed_rows + 127) // 128) + n_tiles = max(1, (n + 127) // 128) + micro_work_tiles = m_tiles * n_tiles🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines 842 - 844, The micro_work_tiles calculation can overcount when routed_rows < 128; replace routed_rows * ceil(n/128) with tile counts so work is M_tiles × N_tiles: compute m_tiles = max(1, (routed_rows + 128 - 1) // 128) and n_tiles = max(1, (n + 128 - 1) // 128) then set micro_work_tiles = m_tiles * n_tiles so micro_mac (computed via tuned_mac/_lookup_mac_ladder, base_mac and micro_work_tiles) is correctly clamped to the actual M_tiles × N_tiles parallel work.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/cute_dsl/b12x_moe.py`:
- Around line 217-240: The current workspace preallocation uses
select_sm120_moe_backend based only on routed rows, which can choose "dynamic"
even when launch_sm120_moe would force static (dynamic is only valid when
num_local_experts == num_experts); update the allocation logic around
max_routed_rows so that before calling allocate_sm120_dynamic_workspace you also
require self.num_local_experts == self.num_experts (mirror launch_sm120_moe),
otherwise force allocate_sm120_static_workspace; reference
select_sm120_moe_backend, max_routed_rows, allocate_sm120_dynamic_workspace,
allocate_sm120_static_workspace, launch_sm120_moe and the use_cuda_graph/backend
inference behavior to ensure the preallocated workspace cannot lock in an
invalid dynamic backend.
- Around line 64-67: The APIs currently accept arbitrary output_dtype but the
SM12x kernels in moe_dispatch.py are hardcoded to cutlass.BFloat16, so add an
explicit runtime guard: in each function or constructor that accepts the
parameter named output_dtype (e.g., the b12x_moe signature and the other
occurrences of output_dtype in this file), assert or raise a clear ValueError
unless output_dtype is torch.bfloat16; include a short error message referencing
launch_sm120_moe/moe_dispatch hardcoding so callers know why only BF16 is
allowed. Ensure the same check is applied to all other places where output_dtype
is accepted in this module.
---
Duplicate comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Line 138: The buffer compact_topk_ids is currently allocated with size
[state_E] (per-expert) but must be sized for routed pairs (flattened routing
ids) consumed by the micro pre-pass; change the allocation(s) of
compact_topk_ids to length num_tokens * top_k (or the explicit
routed_pairs_count/num_routed_pairs variable used in routing) and keep dtype
int32, and update both places where compact_topk_ids is created/allocated so the
micro pre-pass (which reads it as a flattened routing-id buffer) no longer
overruns the array.
- Around line 842-844: The micro_work_tiles calculation can overcount when
routed_rows < 128; replace routed_rows * ceil(n/128) with tile counts so work is
M_tiles × N_tiles: compute m_tiles = max(1, (routed_rows + 128 - 1) // 128) and
n_tiles = max(1, (n + 128 - 1) // 128) then set micro_work_tiles = m_tiles *
n_tiles so micro_mac (computed via tuned_mac/_lookup_mac_ladder, base_mac and
micro_work_tiles) is correctly clamped to the actual M_tiles × N_tiles parallel
work.
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 38-56: Replace the custom is_sm100_family() check with the
repo-standard helper: import and use flashinfer.utils.is_sm100a_supported() for
the skip gate; specifically remove or stop using the local is_sm100_family() and
change the sm100_required pytest.mark.skipif(...) to call
flashinfer.utils.is_sm100a_supported(), ensuring the test uses the common
utility (and add the import if missing).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6076caa1-507c-4917-b9bf-cd104a8b7d8c
📒 Files selected for processing (4)
flashinfer/fused_moe/cute_dsl/b12x_moe.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.pyflashinfer/fused_moe/cute_dsl/fused_moe.pytests/moe/test_cute_dsl_fused_moe.py
|
/bot stop |
|
/bot run |
|
The GitLab CI pipeline #48713073 has been cancelled. |
|
/bot stop |
|
The GitLab CI pipeline #48714380 has been cancelled. |
|
/bot run |
| @@ -0,0 +1,1334 @@ | |||
| """ | |||
| Copyright (c) 2025 by FlashInfer team. | |||
|
@flashinfer-bot run |
📌 Description
Summary
New SM120/SM121 MoE APIs (
b12x_fused_moe,B12xMoEWrapper) with:max(0,x)²) for non-gated MoE (Nemotron-Super) across all three SM120 kernel backends (micro, static, dynamic)cutlass_fused_moeandcute_dsl_fp4_block_scale_moeroutines, with corrected TFLOPS/bandwidth calculations for non-gated activationsb12x_fused_moe, SM100 keepscute_dsl_fused_moe_nvfp4API separation
cute_dsl_fused_moe_nvfp4(FP4 input)CuteDslMoEWrapperb12x_fused_moe(bf16 input)B12xMoEWrapperThe SM100 APIs (
cute_dsl_fused_moe_nvfp4,CuteDslMoEWrapper) are restored to SM100-only scope — no SM120 dispatch, noactivation_typeparameter.Micro kernel
Ported from b12x. Selected automatically when
routed_rows ≤ 20(top_k=1) or≤ 40(top_k>1). Key optimizations vs the static kernel:all_rows_uniquefast path: whennum_tokens=1and every expert is unique, skips atomic row counting and uses O(1) work-tile assignmentReLU2 activation
Added
activationparameter ("silu"default,"relu2") to all SM120 kernel classes viaself.is_gatedcompile-time branching (cutlass.const_expr):StorageGated(3 pipelines, gate+up buffers) vsStorageRelu2(2 pipelines, single FC1 buffer)silu(gate) * upvsrelu(x)²Exposed through
activation_typeparameter onCuteDslMoEWrapperandcute_dsl_fused_moe_nvfp4APIs.API usage
Functional
Wrapper (CUDA graph compatible)
Example micro benchmarks
🔍 Related Issues
#3013
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Tests