
feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2#3080

Open
bkryu wants to merge 14 commits into flashinfer-ai:main from bkryu:b12x_micro_kernel

Conversation


@bkryu (Collaborator) commented Apr 15, 2026

📌 Description

Summary

New SM120/SM121 MoE APIs (b12x_fused_moe, B12xMoEWrapper) with:

  • Micro kernel for tiny decode batches (≤20-40 routed rows) on SM120/SM121, with Triton routing compaction pre-pass and MAC tuning ladder
  • ReLU2 activation (max(0,x)²) for non-gated MoE (Nemotron-Super) across all three SM120 kernel backends (micro, static, dynamic)
  • Benchmark ReLU2 support for both cutlass_fused_moe and cute_dsl_fp4_block_scale_moe routines, with corrected TFLOPS/bandwidth calculations for non-gated activations
  • Clean API separation: SM120 uses b12x_fused_moe, SM100 keeps cute_dsl_fused_moe_nvfp4

API separation

GPU           Functional API                         Wrapper API
SM100/SM103   cute_dsl_fused_moe_nvfp4 (FP4 input)   CuteDslMoEWrapper
SM120/SM121   b12x_fused_moe (bf16 input)            B12xMoEWrapper

The SM100 APIs (cute_dsl_fused_moe_nvfp4, CuteDslMoEWrapper) are restored to SM100-only scope — no SM120 dispatch, no activation_type parameter.

Micro kernel

Ported from b12x. Selected automatically when routed_rows ≤ 20 (top_k=1) or ≤ 40 (top_k>1). Key optimizations vs the static kernel:

  • Triton compact pre-pass: remaps global expert IDs to dense local indices, eliminating CAS-based expert discovery inside the kernel
  • all_rows_unique fast path: when num_tokens=1 and every expert is unique, skips atomic row counting and uses O(1) work-tile assignment
  • MAC tuning ladder: per-routed-row optimal cluster counts from b12x decode profiling, capped against hardware SM count to prevent deadlocks
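The compaction pre-pass described above amounts to a first-seen dense renumbering of the routed expert IDs. A pure-Python sketch of the mapping (the function name and output layout mirror the names used in this PR's description, but the real pass is a single Triton program, not this loop):

```python
def compact_topk_ids(flat_ids):
    """Remap global expert IDs to dense local indices, first-seen order.

    Returns (compact_ids, active_expert_count, weight_expert_ids), mirroring
    the outputs described for the Triton pre-pass.
    """
    local_of = {}  # global expert id -> dense local index
    compact_ids = []
    for e in flat_ids:
        if e not in local_of:
            local_of[e] = len(local_of)  # assign the next free local slot
        compact_ids.append(local_of[e])
    # Inverse mapping: local index -> global expert id (for weight gathers)
    weight_expert_ids = sorted(local_of, key=local_of.get)
    return compact_ids, len(local_of), weight_expert_ids
```

With dense local indices in hand, the kernel can index expert weights directly instead of discovering active experts via CAS loops at runtime.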

ReLU2 activation

Added activation parameter ("silu" default, "relu2") to all SM120 kernel classes via self.is_gated compile-time branching (cutlass.const_expr):

  • Storage: StorageGated (3 pipelines, gate+up buffers) vs StorageRelu2 (2 pipelines, single FC1 buffer)
  • FC1: dual GEMM (gate+up) for SiLU vs single GEMM for ReLU2
  • Activation: silu(gate) * up vs relu(x)²
  • DMA: up-projection TMA loads eliminated for ReLU2
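As a scalar reference for the two FC1 epilogues, the gated (SiLU) and non-gated (ReLU2) paths reduce to the following (a sketch of the math only; the kernels apply this elementwise before FP4 re-quantization, and these helper names are illustrative):

```python
import math

def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))

def gated_fc1(gate: float, up: float) -> float:
    # SiLU path: two FC1 GEMM outputs (gate, up) combined as silu(gate) * up
    return silu(gate) * up

def relu2_fc1(x: float) -> float:
    # ReLU2 path: single FC1 GEMM output, max(0, x)^2
    return max(0.0, x) ** 2
```

The single-GEMM ReLU2 path is what allows the storage and pipeline reductions listed above (one FC1 buffer, no up-projection TMA loads).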

Exposed through activation_type parameter on CuteDslMoEWrapper and cute_dsl_fused_moe_nvfp4 APIs.

API usage

Functional

from flashinfer import b12x_fused_moe

output = b12x_fused_moe(
    x=hidden_states_bf16,       # bf16 input (kernel fuses quantization)
    w1_weight=w1_fp4, w1_weight_sf=w1_sf, w1_alpha=w1_alpha,
    fc2_input_scale=fc2_scale,
    w2_weight=w2_fp4, w2_weight_sf=w2_sf, w2_alpha=w2_alpha,
    token_selected_experts=topk_ids,
    token_final_scales=topk_weights,
    num_experts=512, top_k=22,
    activation="relu2",  # or "silu" (default)
)

Wrapper (CUDA graph compatible)

from flashinfer import B12xMoEWrapper

moe = B12xMoEWrapper(
    num_experts=512, top_k=22,
    hidden_size=1024, intermediate_size=2688,
    use_cuda_graph=True, activation="relu2",
)
output = moe.run(x=hidden_states_bf16, ...)

Example micro benchmarks

# b12x CuTe DSL MoE for 1-token Nemotron 3 Super size
python benchmarks/flashinfer_benchmark.py --routine cute_dsl_fp4_block_scale_moe --activation-type Relu2 --num_tokens 1 --hidden_size 1024 --intermediate_size 2688 --num_experts 512 --top_k 22 --use_cuda_events --num_iters 50
# Equivalent cutlass_fused_moe benchmark
python benchmarks/flashinfer_benchmark.py --routine cutlass_fused_moe --cutlass_variant nvfp4 --activation-type Relu2 --num_tokens 1 --hidden_size 1024 --intermediate_size 2688 --num_experts 512 --top_k 22 --quantized_input --use_cuda_events --num_iters 50

🔍 Related Issues

#3013

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Non‑gated ReLU2 activation and dual gated/non‑gated FC1 layouts for MoE; activation selectable at runtime.
    • New micro‑kernel backend plus routing‑ID compaction for improved single‑token/small‑batch performance.
    • SM12x (b12x) fused‑MoE functional API and CUDA‑graph‑friendly wrapper exported for SM12x workflows; runtime maps activations to kernel implementations.
    • CuTe‑DSL helpers added to support ReLU2 + FP4 quantization.
  • Tests

    • End‑to‑end tests for ReLU2, gated vs non‑gated flows, micro‑kernel paths, CUDA graph replay, and FP4 numerical agreement.

@coderabbitai
Contributor

coderabbitai bot commented Apr 15, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Added gated (SiLU) vs non‑gated (ReLU²) activation support across benchmarks and fused MoE CuTe‑DSL kernels; introduced SM12x b12x functional API and wrapper, a tiny‑decode micro backend with routing compaction, activation‑aware kernel compilation/storage, weight/layout/quantization changes, and new tests.

Changes

Cohort / File(s) Summary
Benchmarks & Utilities
benchmarks/routines/moe.py, benchmarks/routines/moe_utils.py
Propagated activation / is_gated into benchmarks and TFLOPs/bandwidth models; adjusted w1_rows/w1_cols accounting and benchmark wiring for gated vs non‑gated layouts.
FP4 Helpers
flashinfer/cute_dsl/fp4_common.py
Added relu2_16 and relu2_quantize_block_fp4 JIT helpers to support ReLU² fused with FP4 block quantization.
CuTe‑DSL Public Surface
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py
Exported MoEMicroKernel in package __all__.
Routing Compaction Kernel
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py
New Triton kernel compact_topk_ids to compact flattened top‑k IDs into dense local indices and produce active_expert_count / weight_expert_ids mapping.
Dispatch & Micro Backend
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
Added micro (tiny‑decode) backend with micro kernel compile/cache path, routing compaction pre‑pass, activation propagation, MAC override ladder, and workspace field for compacted IDs.
Static Kernel
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
Added activation param; storage, pipelines, FC1 tiling, shared‑memory layout, and fused activation/quant logic made conditional on gated (SiLU) vs non‑gated (ReLU²).
Dynamic Kernel
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
Added activation param and self.is_gated; adjusted producer/consumer tiling, pipelines, and fused activation computation for gated vs non‑gated paths.
SM12x Functional API & Wrapper
flashinfer/fused_moe/cute_dsl/b12x_moe.py, flashinfer/fused_moe/cute_dsl/__init__.py, flashinfer/fused_moe/__init__.py, flashinfer/__init__.py
New b12x_fused_moe functional API and B12xMoEWrapper; enforce CUDA13+, activation propagation, backend selection, CUDA‑graph buffer preallocation, and public re‑exports.
CuteDSL NVFP4 Path & Wrapper Simplification
flashinfer/fused_moe/cute_dsl/fused_moe.py
Removed SM12x special‑case paths from NVFP4 flow and reduced supported compute capability decorators to SM100/103; unified NVFP4 runner/autotune path.
Tests
tests/moe/test_cute_dsl_fused_moe.py, tests/moe/test_b12x_fused_moe.py
Adjusted SM‑family gating and weight prep; added comprehensive SM12x b12x test suite (functional, wrapper, micro‑kernel, CUDA graph) including ReLU² coverage and numerical checks.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant Dispatch as Dispatch Layer
    participant Compact as Triton Compact
    participant Cache as Kernel Cache
    participant Micro as Micro Kernel
    participant Static as Static Kernel
    participant Dynamic as Dynamic Kernel
    participant Activation as Activation Func

    Caller->>Dispatch: submit token_selected_experts, weights, scales, activation
    Dispatch->>Dispatch: compute routed_rows = num_tokens * top_k
    Dispatch->>Dispatch: select backend (micro/static/dynamic) using cutovers

    alt Micro Path
        Dispatch->>Compact: compact_topk_ids(topk_ids)
        Compact-->>Dispatch: compact_ids, active_expert_count, weight_expert_ids
        Dispatch->>Cache: lookup/compile micro kernel (activation, mac_override)
        Cache-->>Dispatch: micro kernel
        Dispatch->>Micro: launch(compact_ids, activation, weights, scales)
        Micro->>Activation: apply SiLU or ReLU²
        Micro->>Caller: write outputs
    else Static Path
        Dispatch->>Cache: lookup/compile static kernel (activation)
        Cache-->>Dispatch: static kernel
        Dispatch->>Static: launch(topk_ids, activation, weights, scales)
        Static->>Activation: apply SiLU or ReLU²
        Static->>Caller: write outputs
    else Dynamic Path
        Dispatch->>Dynamic: launch dynamic kernel (activation)
        Dynamic->>Activation: apply activation variant
        Dynamic->>Caller: write outputs
    end
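The cutover step in the sequence diagram can be sketched as a small dispatch helper. The micro-path thresholds (routed_rows ≤ 20 for top_k=1, ≤ 40 for top_k>1) come from the PR description; the static/dynamic cutover shown here is a hypothetical placeholder, not the actual heuristic:

```python
def select_backend(num_tokens: int, top_k: int) -> str:
    """Pick a kernel backend from the routed-row count (illustrative)."""
    routed_rows = num_tokens * top_k
    micro_cutover = 20 if top_k == 1 else 40
    if routed_rows <= micro_cutover:
        return "micro"   # tiny-decode path with Triton compaction pre-pass
    if routed_rows <= 1024:  # hypothetical static/dynamic cutover
        return "static"
    return "dynamic"
```

Only the micro branch triggers the compact_topk_ids pre-pass; the static and dynamic branches launch directly on the flattened top-k IDs.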

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • cyx-6
  • jiahanc
  • IwakuraRein
  • samuellees
  • jimmyzho
  • aleozlx

Poem

🐰 Hoppity-hop, experts line the track,
SiLU gates wiggle, ReLU² leaps back.
Tiny kernels hustle, compaction sings,
Quantized carrots fuel faster things.
A rabbit cheers: kernels, hop—great spring!

🚥 Pre-merge checks | ✅ 3 passed

  • Title check (✅ Passed): The title accurately summarizes the main feature additions: new SM120 MoE APIs (b12x_fused_moe/B12xMoEWrapper) with micro kernel and ReLU2 activation support.
  • Description check (✅ Passed): The description provides comprehensive detail: clear summary of changes, API separation table, micro kernel explanation, ReLU2 activation details, usage examples, and related issues. All required sections are substantially filled out.
  • Docstring Coverage (✅ Passed): Docstring coverage is 81.48%, which meets the required threshold of 80.00%.


Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request adds support for non-gated MoE activations (ReLU2) and introduces a specialized micro kernel for small decode batches on Blackwell SM120/SM121. The changes include updates to the static and dynamic CuTe DSL kernels, a new Triton-based ID-compaction pre-pass, and updated benchmarking and testing utilities. The review flags that the moe_micro_kernel.py file appears to be missing from the PR and notes a potential out-of-bounds risk when slicing the workspace buffer for compact IDs.

Comment thread flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
from .triton_compact import compact_topk_ids as _triton_compact_topk_ids

# Run Triton pre-pass to compact global expert IDs to dense local indices
compact_ids = workspace.compact_topk_ids[: flat_ids.numel()]
Contributor

Severity: medium

There is a potential out-of-bounds risk if flat_ids.numel() exceeds state_E. While the micro kernel path is currently restricted to routed_rows <= 40, the workspace allocation for compact_topk_ids uses state_E. It would be safer to ensure that the slice does not exceed the allocated size of the workspace buffer, or add an explicit check.

Collaborator Author

Good catch on the defensive coding. In practice this can't overflow — compact_topk_ids is sized [state_E] (typically 256-512) while flat_ids.numel() is at most 40 on the micro path (the cutover threshold). But the invariant should be explicit. Added an assertion in the next commit.
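A minimal sketch of the invariant discussed in this thread, written against plain lists so it stands alone (the real assertion in moe_dispatch.py operates on tensor sizes; this helper name is illustrative):

```python
def checked_compact_slice(workspace_buf, routed_rows):
    # The micro path only runs with routed_rows <= 40, while the workspace
    # buffer is sized to the local expert count (typically 256-512), so this
    # should never fire -- but the invariant is now explicit.
    assert routed_rows <= len(workspace_buf), (
        f"compact_topk_ids needs {routed_rows} slots, "
        f"workspace has {len(workspace_buf)}"
    )
    return workspace_buf[:routed_rows]
```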

Contributor

@coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
benchmarks/routines/moe.py (1)

1790-1813: ⚠️ Potential issue | 🟡 Minor

Pass is_gated into the FP8 bandwidth model too.

TFLOPS now distinguishes gated vs non-gated activations, but both FP8 bandwidth calls still rely on the default path. ReLU2 runs will therefore report inconsistent bandwidth.

🛠️ Suggested fix
     tb_per_sec = calculate_moe_kernel_bandwidth(
         num_tokens,
         hidden_size,
         intermediate_size,
         num_experts,
         top_k,
         median_time,
         input_dtype,
         weight_dtype,
         input_format="fp8",
         weight_format="fp8",
         routing_logits_dtype=routing_logits.dtype,
         active_experts=int(selected_experts.unique().numel()),
         verbose=args.verbose,
+        is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu),
     )
     tb_per_sec = calculate_moe_kernel_bandwidth(
         num_tokens,
         hidden_size,
         intermediate_size,
         num_experts,
         top_k,
         median_time,
         input_dtype,
         weight_dtype,
         input_format="fp8",
         weight_format="fp8",
         routing_logits_dtype=routing_logits.dtype,
         active_experts=int(selected_experts.unique().numel()),
         verbose=args.verbose,
+        is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu),
     )

Also applies to: 2025-2048

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` around lines 1790 - 1813, The FP8 bandwidth call
is missing the is_gated flag, causing gated vs non-gated activations to report
inconsistent bandwidth; update the calculate_moe_kernel_bandwidth invocation(s)
to pass the same is_gated boolean used for calculate_moe_tflops (e.g.,
is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu))
so both calculate_moe_tflops(...) and calculate_moe_kernel_bandwidth(...)
receive the same gating hint (apply the same change to the other occurrences
around the later block that mirrors this code).
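The gated/non-gated distinction this finding propagates changes the FLOP accounting itself: a gated FC1 runs two GEMMs (gate + up) against a [2 * intermediate_size, hidden_size] weight, while a non-gated FC1 runs one. A sketch of the adjusted model (the function and parameter names here are illustrative, not the benchmark's actual helpers):

```python
def moe_gemm_flops(routed_rows, hidden, intermediate, is_gated):
    """Approximate MoE GEMM FLOPs: 2*M*N*K per GEMM, FC1 (x2 if gated) + FC2."""
    fc1 = (2 if is_gated else 1) * 2 * routed_rows * hidden * intermediate
    fc2 = 2 * routed_rows * intermediate * hidden
    return fc1 + fc2
```

Bandwidth accounting shifts the same way: a non-gated expert reads half the FC1 weight bytes, which is why both models need the same is_gated hint.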
flashinfer/fused_moe/cute_dsl/fused_moe.py (1)

362-394: ⚠️ Potential issue | 🟠 Major

Reject relu2 outside the SM120/SM121 path.

These new public parameters are only honored in the SM120 branch. The fallback path still goes through _moe_core_impl(), which hard-wires the SwiGLU fusion helper, so activation_type="relu2" on SM100/SM103 can run the wrong math or hit mismatched FC1 shapes instead of failing fast.

🛠️ Suggested guard
@@
-        self.activation_type = activation_type
+        if activation_type not in {"silu", "relu2"}:
+            raise ValueError(f"Unsupported activation_type: {activation_type!r}")
+        self.activation_type = activation_type
@@
         major, minor = torch.cuda.get_device_capability(device)
         self._is_sm120 = major == 12
+        if activation_type != "silu" and not self._is_sm120:
+            raise ValueError(
+                "activation_type='relu2' is only supported on SM120/SM121"
+            )
 def cute_dsl_fused_moe_nvfp4(
@@
-    if num_local_experts is None:
+    if activation_type not in {"silu", "relu2"}:
+        raise ValueError(f"Unsupported activation_type: {activation_type!r}")
+
+    if num_local_experts is None:
         num_local_experts = num_experts
@@
     major, _ = torch.cuda.get_device_capability(x.device)
     if major == 12:
         ...
+    elif activation_type != "silu":
+        raise ValueError(
+            "activation_type='relu2' is only supported on SM120/SM121"
+        )

Also applies to: 827-916

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 362 - 394, The
constructor accepts activation_type but the non-SM120/SM121 path still calls
_moe_core_impl which assumes SwiGLU; add a runtime guard in the initializer (or
immediately before dispatch to _moe_core_impl) that checks activation_type and
the detected GPU SM version and either raise a clear error or restrict allowed
values when SM < 120 (e.g., if activation_type == "relu2" and not on SM120/121,
raise ValueError). Update the dispatch code path that calls _moe_core_impl (and
any fallback branches referenced around the alternate implementation) to enforce
this same check so relu2 is only honored on the SM120/SM121 branch and cannot
silently run with the SwiGLU helper.
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)

138-185: ⚠️ Potential issue | 🔴 Critical

Size compact_topk_ids for routed rows, not experts.

Line 824 slices this buffer to flat_ids.numel() (num_tokens * top_k), but the workspace only allocates state_E entries. Any micro launch with more routed rows than local experts will write past the end of the buffer.

🛠️ Suggested fix
-    compact_topk_ids: torch.Tensor  # [state_E] int32, for micro kernel pre-pass
+    compact_topk_ids: torch.Tensor  # [max_rows] int32, for micro kernel pre-pass
@@
-        compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device),
+        compact_topk_ids=torch.empty(max_rows, dtype=torch.int32, device=device),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
138 - 185, compact_topk_ids is currently allocated with length state_E but is
later sliced to flat_ids.numel() (num_tokens * top_k) in the micro-kernel
pre-pass, causing out-of-bounds writes when num_tokens > state_E; in
allocate_sm120_static_workspace change the compact_topk_ids allocation in
Sm120StaticMoEWorkspace to have capacity for the worst-case routed rows times
top-k (e.g. torch.empty(state_E * max_rows * num_topk, dtype=torch.int32,
device=device) or at minimum torch.empty(max_rows * state_E * num_topk, ...)) so
flat_ids.numel() can always fit, and keep references to compact_topk_ids,
allocate_sm120_static_workspace, Sm120StaticMoEWorkspace, num_topk, max_rows,
and state_E to locate the change.
🧹 Nitpick comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py (1)

76-86: Add a defensive size guard for micro-only usage.

This kernel is O(BLOCK²) in a single program; adding an explicit upper bound makes accidental large launches fail fast with a clear message.

Possible guardrail
     block = triton.next_power_of_2(total_pairs)
+    if block > 256:
+        raise ValueError(
+            f"compact_topk_ids is intended for micro batches; got total_pairs={total_pairs}"
+        )
     num_warps = 1 if block <= 16 else 2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines
76 - 86, Add a defensive size guard before launching _compact_topk_ids_kernel to
prevent accidental large O(BLOCK²) launches: compute block =
triton.next_power_of_2(total_pairs) as you do, then check against a small hard
limit (e.g. MAX_BLOCK = 64 or 128) and/or a MAX_PAIRS derived limit and raise a
clear RuntimeError if block > MAX_BLOCK (include block and total_pairs in the
message). Keep the existing num_warps logic and kernel args
(_compact_topk_ids_kernel, topk_ids, compact_topk_ids, weight_expert_ids,
active_expert_count, total_pairs, BLOCK=block, num_warps=num_warps) unchanged;
just insert the guard using the same symbols so oversized launches fail fast
with a descriptive error.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1367-1371: The code silently maps unsupported ActivationType
values to "silu"; change the logic to validate args.activation_type against the
supported mapping instead of defaulting. Use the _ACT_STR dict to look up
activation_str and if the activation_type is not present raise a clear exception
(e.g., ValueError) mentioning the unsupported ActivationType and listing
supported keys; also compute is_gated from ActivationType.Geglu and
ActivationType.Swiglu as before but ensure Geglu is rejected if not in _ACT_STR
so it cannot silently run the SiLU kernel.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py`:
- Around line 539-541: Comments describing w13 tile ordering conflict with
actual code: update the inline comments around the w13 TMA descriptor creation
(the lines that mention "Gate tiles at N=..." and "up tiles at N=...") to
reflect the actual ordering used by the code path where gate_slice_idx =
intermediate_slice + gate_tile_cnt (i.e., up tiles occupy the first half of
N-tiles and gate tiles the second half). Locate the call to
self._dense_cls._make_tma_atoms_and_tensors that produces tma_b_w13/gB_w13 and
any other similar comments (also around the other occurrence near lines
~1168-1171) and change the wording so it states "Up tiles at N=0..I_tp/tile_N-1,
Gate tiles at N=I_tp/tile_N..2*I_tp/tile_N-1" or equivalent that matches the
gate_slice_idx logic.

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1093-1096: Replace the custom skip gating that defines
sm120_cuda13 (which currently uses is_sm120_family() and _has_cuda_13()) with
the repository-standard capability checks from flashinfer.utils or the API
capability method; specifically, remove is_sm120_family()/_has_cuda_13() and use
the appropriate flashinfer.utils helper (e.g., is_sm120_supported() or analogous
is_sm90a_supported()/is_sm100a_supported()) or call
api_name.is_compute_capability_supported(cc) to decide the skip. Ensure the new
marker still uses pytest.mark.skipif(...) with a descriptive reason string
indicating the required SM/CUDA capability.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5e3f31a2-4325-47e8-9db8-0acd0b10bef6

📥 Commits

Reviewing files that changed from the base of the PR and between 25b324d and ea4ad45.

📒 Files selected for processing (11)
  • benchmarks/routines/moe.py
  • benchmarks/routines/moe_utils.py
  • flashinfer/cute_dsl/fp4_common.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment thread benchmarks/routines/moe.py Outdated
Comment thread flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
Comment thread tests/moe/test_cute_dsl_fused_moe.py Outdated
@bkryu added the v0.6.9 label (release blocker for 0.6.9) on Apr 15, 2026
@bkryu
Collaborator Author

bkryu commented Apr 15, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !556 has been created, and the CI pipeline #48636109 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 2030-2037: The bandwidth computation is still using the default
gated accounting while TFLOPS now selects non-gated accounting for some
activations; update the call to calculate_moe_kernel_bandwidth to pass the same
is_gated boolean used for calculate_moe_tflops (i.e.,
is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu)
or the expression that produced non-gated for Relu2) so both TFLOPS and kernel
bandwidth use the same activation-aware gating flag (refer to
calculate_moe_tflops and calculate_moe_kernel_bandwidth to locate the calls).
- Around line 1795-1802: The TFLOPS call uses args.activation_type to set
is_gated but this routine still constructs gated FC1 tensors and
run_fp8_block_moe never receives an activation flag, so reported TFLOPS can
diverge from the executed kernel; fix by not switching the is_gated flag based
on args.activation_type here — either hard-code is_gated=True (gated-only path)
or derive is_gated from the same gated-only indicator used when building tensors
(e.g., the 2 * intermediate_size gated FC1 logic) and/or update
run_fp8_block_moe to accept and forward an activation_type so activation-based
toggles are consistent with calculate_moe_tflops; reference
calculate_moe_tflops, run_fp8_block_moe, args.activation_type and the gated FC1
construction (2 * intermediate_size) when making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 43a02c6d-a59b-416a-b805-d1e691a5397e

📥 Commits

Reviewing files that changed from the base of the PR and between ea4ad45 and 58b1168.

📒 Files selected for processing (6)
  • benchmarks/routines/moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
  • tests/moe/test_cute_dsl_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/moe/test_cute_dsl_fused_moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py

Comment thread benchmarks/routines/moe.py
Comment thread benchmarks/routines/moe.py
@bkryu bkryu changed the title feat: Add micro kernel + ReLU2 activation for SM120 b12x fused MoE feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2 Apr 16, 2026
Collaborator

@nv-yunzheq nv-yunzheq left a comment

Generally looks good to me. Left a few comments to make sure the sm10x cute dsl moe wasn't changed by this change.

num_local_experts=num_experts_local,
scatter_output=moe_output,
)
# NOTE: SM120/SM121 dispatch is handled by callers (CuteDslMoEWrapper.run
Collaborator

Remove this as well?

Collaborator Author

Thanks Yunzhe for catching this. Doing another scan to remove references to sm120/121/b12x in the cute dsl MoE path.

Collaborator

Here the comments are not correct. Maybe the cleanest approach for this file is to just restore it to its state before b12x was added.

Collaborator Author

Ditto as above. Reverting to

x: NVFP4-quantized input [num_tokens, hidden_size // 2]. 
x_sf: Scale factors for x.

Comment thread tests/moe/test_cute_dsl_fused_moe.py Outdated
sm100_required = pytest.mark.skipif(
not is_sm100_family() or (is_sm120_family() and not _has_cuda_13()),
reason="Requires SM100/SM103 or SM120/SM121 GPU (SM120 requires CUDA 13+)",
not is_sm100_family() or is_sm120_family(),
Collaborator

Let's also maybe restore this file to its state before b12x was added. I think the logic here is not correct.

Collaborator Author

Good catch let me revert tests/moe/test_cute_dsl_fused_moe.py

token_final_scales: torch.Tensor,
num_experts: int,
top_k: int,
num_local_experts: Optional[int] = None,
Collaborator

I suggest putting some non-essential args after `*` as keyword-only args.

Though output, act, etc. are common enough. Not sure about the w1/w2 alphas though; are they like global scales?

Collaborator Author

Good point for clarity. Reordering in the next commit.

Regarding your question, yes, the w1 and w2 alphas are per-expert global scales.

Collaborator Author

@bkryu bkryu Apr 16, 2026

(Also, not just reordered but added the * as suggested)
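The keyword-only reordering discussed above can be sketched like this. The parameter list here is illustrative only, not the real `b12x_fused_moe` signature: everything after the bare `*` must be passed by keyword, which keeps call sites readable for the less-essential scale and activation arguments.

```python
def b12x_fused_moe_sketch(
    x,
    w1,
    w2,
    topk_ids,
    topk_weights,
    num_experts,
    top_k,
    *,  # everything below is keyword-only
    num_local_experts=None,
    w1_alpha=None,  # per-expert global scale for w1 (illustrative name)
    w2_alpha=None,  # per-expert global scale for w2 (illustrative name)
    activation="silu",
    output=None,
):
    """Illustrative keyword-only layout; the real signature may differ."""
    if num_local_experts is None:
        num_local_experts = num_experts
    return num_local_experts, activation
```

With this layout, accidentally passing `num_local_experts` positionally raises a `TypeError` instead of silently binding to the wrong parameter.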

@@ -66,12 +66,12 @@ def _has_cuda_13():
not is_cute_dsl_available(), reason="CuteDSL not available"
)
sm100_required = pytest.mark.skipif(
Collaborator

Will discuss it more in a call so I can catch up on my understanding.

Collaborator Author

This is a good point and I should have been more clear. Reverting the entire file to be at the state prior to any b12x work because I am moving b12x moe tests that run on sm12x to a separate test file

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
tests/moe/test_cute_dsl_fused_moe.py (1)

38-56: ⚠️ Potential issue | 🟠 Major

Use the repo capability helper here, not a local major == 10 check.

This reintroduces custom architecture gating and drops the repo-standard supported-runtime check. On an SM10x environment that the API still considers unsupported, these tests will now run instead of skip and fail for the wrong reason. Please gate sm100_required with flashinfer.utils.is_sm100a_supported(...) or the API capability method.

🛠️ Proposed fix
+from flashinfer.utils import is_sm100a_supported
+
 def is_sm100_family():
-    """Check for SM100 family (Blackwell: SM100, SM103).
+    """Check for a supported SM100/SM103 runtime.
 
     CuteDSL MoE NVFP4 kernels on SM10x use cute_dsl_fused_moe_nvfp4 API.
     SM120/121 tests are in test_b12x_fused_moe.py instead.
     """
     if not torch.cuda.is_available():
         return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major == 10
+    return is_sm100a_supported(torch.cuda.current_device())

As per coding guidelines tests/**/*.py: Skip test execution on unsupported GPU architectures using flashinfer.utils check functions (is_sm90a_supported(), is_sm100a_supported(), etc.) or API methods like api_name.is_compute_capability_supported(cc).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 56, Replace the
custom is_sm100_family() check and the sm100_required skip marker with the
repo-standard capability check: remove or stop using is_sm100_family() and
instead gate sm100_required using flashinfer.utils.is_sm100a_supported(...) (or
the API capability method api_name.is_compute_capability_supported(cc)) so the
pytest.mark.skipif uses the repo helper; update references to the skip decorator
(sm100_required) to call flashinfer.utils.is_sm100a_supported() and ensure
torch.cuda.is_available() logic is handled by that helper rather than checking
props.major == 10 in the is_sm100_family function.
🧹 Nitpick comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)

543-764: Consider extracting shared fake tensor generation.

The _get_micro_kernel function duplicates most of the fake tensor generation code from _get_static_kernel (lines 611-730 mirror 385-503). Consider extracting a helper function to reduce duplication and improve maintainability.

💡 Example helper extraction
def _make_moe_fake_tensors(
    state_E: int,
    weight_E: int,
    m: int,
    k: int,
    w1_rows: int,
    n: int,
    num_topk: int,
    max_rows: int,
    rows_pad_k: int,
    cols_pad_k: int,
    topk_ids_dtype: torch.dtype,
):
    """Build fake tensors for MoE kernel compilation."""
    ab_dtype = cutlass.Float4E2M1FN
    sf_dtype = cutlass.Float8E4M3FN
    a_dtype = cutlass.BFloat16
    alpha_dtype = cutlass.Float32
    # ... shared fake tensor creation ...
    return (a_input_fake, topk_ids_fake, ..., token_weights_fake)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
543 - 764, The fake-tensor construction in _get_micro_kernel duplicates the
block from _get_static_kernel; extract the shared creation into a helper (e.g.,
_make_moe_fake_tensors) that accepts the identifying params (state_E, weight_E,
m, k, w1_rows, n, num_topk, max_rows, rows_pad_k, cols_pad_k, topk_ids_dtype)
and returns the tuple of fake tensors used by both functions (a_input_fake,
topk_ids_fake, topk_weights_fake, packed_a_fake, sfa_fake,
packed_a_storage_fake, scale_storage_fake, barrier_count_fake,
barrier_epoch_fake, b_w13_fake, sfb_w13_fake, b_down_fake, sfb_down_fake,
row_counts_fake, active_expert_count_fake, weight_expert_ids_fake,
global_to_local_expert_fake, input_gs_fake, alpha_fake, down_alpha_fake,
global_scale_fake, scatter_fake, token_map_fake, token_weights_fake); replace
the duplicate blocks in _get_micro_kernel and _get_static_kernel to call this
helper and use its returned values before compilation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4c52a6fd-d2eb-4171-a08e-39319f27fbbc

📥 Commits

Reviewing files that changed from the base of the PR and between 210e542 and 3481279.

📒 Files selected for processing (3)
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • tests/moe/test_cute_dsl_fused_moe.py

active_expert_count: torch.Tensor # [1] int32
weight_expert_ids: torch.Tensor # [state_E] int32
global_to_local_expert: torch.Tensor # [weight_E] int32
compact_topk_ids: torch.Tensor # [state_E] int32, for micro kernel pre-pass
Contributor

⚠️ Potential issue | 🔴 Critical

Critical: compact_topk_ids buffer is undersized — will cause assertion failures.

The compact_topk_ids buffer is allocated with size state_E (number of local experts), but it's used to store compacted routing IDs which have size num_tokens * top_k (routed pairs).

At line 827-831, the code asserts flat_ids.numel() <= workspace.compact_topk_ids.numel(), where flat_ids.numel() = num_tokens * top_k. With typical MoE configs (e.g., 8 local experts but up to 40 routed pairs for micro kernel), this assertion will fail.

The buffer should be sized to max_rows * num_topk or at minimum max_rows to accommodate the flattened routing IDs.

🐛 Proposed fix
 @dataclass(kw_only=True)
 class Sm120StaticMoEWorkspace:
     """Scratch buffers for one SM120 static MoE launch."""
     ...
-    compact_topk_ids: torch.Tensor  # [state_E] int32, for micro kernel pre-pass
+    compact_topk_ids: torch.Tensor  # [max_rows * num_topk] int32, for micro kernel pre-pass
     workspace = Sm120StaticMoEWorkspace(
         ...
-        compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device),
+        compact_topk_ids=torch.empty(max_rows * num_topk, dtype=torch.int32, device=device),
     )

Also applies to: 185-185

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` at line 138,
The compact_topk_ids buffer is undersized: change its allocation/annotation
(currently declared as compact_topk_ids: torch.Tensor  # [state_E]) to have
capacity for flattened routed pairs (size max_rows * num_topk, or at minimum
max_rows) so flat_ids.numel() (num_tokens * top_k) never exceeds
workspace.compact_topk_ids.numel(); update the workspace struct/creation sites
(search for compact_topk_ids in moe_dispatch.py and the other occurrence around
line 185) to allocate a 1D int32 tensor of length max_rows * num_topk and adjust
any comments/annotations accordingly so the assertion at the flat_ids.numel() <=
workspace.compact_topk_ids.numel() check (around the 827-831 area) passes
reliably.
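The sizing invariant described above can be checked with a small stand-in. This is pure Python for illustration; the real buffer is a 1D `torch.int32` tensor and the check mirrors the `flat_ids.numel() <= workspace.compact_topk_ids.numel()` assertion.

```python
def compact_topk_ids_capacity(max_rows: int, num_topk: int) -> int:
    # The micro pre-pass flattens routing IDs to one entry per
    # (token, expert) routed pair, so capacity must be
    # max_rows * num_topk, not the local expert count.
    return max_rows * num_topk

def fits(num_tokens: int, top_k: int, capacity: int) -> bool:
    # Mirrors the assertion: flat_ids.numel() <= compact_topk_ids.numel().
    return num_tokens * top_k <= capacity

# Config from the review: 8 local experts but up to 40 routed pairs.
state_E = 8
old_capacity = state_E                           # old allocation: [state_E]
new_capacity = compact_topk_ids_capacity(20, 2)  # new: [max_rows * num_topk]
```

With the old `[state_E]` allocation, 20 tokens at top_k=2 produce 40 flattened IDs against a capacity of 8, which is exactly the assertion failure the review predicts.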

Comment on lines +842 to +844
micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
tuned_mac = _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows)
micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac)
Contributor

⚠️ Potential issue | 🟡 Minor

Incorrect work tile count calculation may cause suboptimal MAC selection.

The formula routed_rows * max(1, (n + 128 - 1) // 128) multiplies raw routed rows by N-dimension tile count, but should instead compute the actual number of work tiles as M_tiles × N_tiles.

This over-estimates the work tile count (e.g., for routed_rows=20, n=4096: current = 20 * 32 = 640 vs correct = ceil(20/128) * 32 = 32), which may prevent MAC from being properly clamped to available parallelism.

🔧 Proposed fix
-        micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
+        m_tiles = max(1, (routed_rows + 127) // 128)
+        n_tiles = max(1, (n + 127) // 128)
+        micro_work_tiles = m_tiles * n_tiles
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
tuned_mac = _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows)
micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac)
m_tiles = max(1, (routed_rows + 127) // 128)
n_tiles = max(1, (n + 127) // 128)
micro_work_tiles = m_tiles * n_tiles
tuned_mac = _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows)
micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
842 - 844, The micro_work_tiles computation is over-counting by multiplying
routed_rows by N_tiles; update micro_work_tiles to compute M_tiles = max(1,
(routed_rows + 128 - 1) // 128) and N_tiles = max(1, (n + 128 - 1) // 128) and
then set micro_work_tiles = M_tiles * N_tiles (or max(1, M_tiles * N_tiles) if
desired). Keep the rest of the logic (calling
_lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) to get tuned_mac and
computing micro_mac = min(tuned_mac or base_mac, micro_work_tiles, base_mac))
unchanged so MAC clamping correctly reflects actual M_tiles × N_tiles
parallelism.
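The overcount and the proposed fix can be reproduced with the numbers from the review comment. This sketch treats `tuned_mac` and `base_mac` as given inputs; the real dispatch code derives them from the MAC ladder and the hardware SM count.

```python
def micro_work_tiles_flagged(routed_rows: int, n: int) -> int:
    # Formula flagged in the review: raw routed rows times N-tile count.
    return max(1, routed_rows * max(1, (n + 128 - 1) // 128))

def micro_work_tiles_fixed(routed_rows: int, n: int) -> int:
    # Proposed fix: count 128x128 work tiles as M_tiles * N_tiles.
    m_tiles = max(1, (routed_rows + 127) // 128)
    n_tiles = max(1, (n + 127) // 128)
    return m_tiles * n_tiles

def clamp_mac(tuned_mac, base_mac: int, work_tiles: int) -> int:
    # MAC may not exceed available work-tile parallelism or the SM budget.
    return min(tuned_mac or base_mac, work_tiles, base_mac)
```

For `routed_rows=20, n=4096` the flagged formula yields 640 tiles while the fixed one yields `ceil(20/128) * ceil(4096/128) = 1 * 32 = 32`, so only the fixed count actually clamps a tuned MAC of, say, 48 down to the available parallelism.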

@bkryu
Collaborator Author

bkryu commented Apr 16, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !556 has been updated with latest changes, and the CI pipeline #48708465 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Apr 16, 2026

/bot stop

@bkryu
Collaborator Author

bkryu commented Apr 16, 2026

/bot run

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48708465 has been cancelled.

@flashinfer-bot
Collaborator

GitLab MR !556 has been updated with latest changes, and the CI pipeline #48713073 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
tests/moe/test_cute_dsl_fused_moe.py (1)

38-56: ⚠️ Potential issue | 🟡 Minor

Use flashinfer.utils.is_sm100a_supported() for the skip gate.

This reintroduces a custom SM10x check and drops the repo-standard CUDA-version guard. The helper in flashinfer.utils already matches the intended SM100/SM103 coverage, so using it keeps the test gate consistent with the rest of the suite.

🔧 Proposed fix
 import pytest
 import torch
 from torch.nn import functional as F

 from flashinfer.cute_dsl import is_cute_dsl_available
+from flashinfer.utils import is_sm100a_supported
-
-
-def is_sm100_family():
-    """Check for SM100 family (Blackwell: SM100, SM103).
-
-    CuteDSL MoE NVFP4 kernels are optimized for SM10x architecture.
-    """
-    if not torch.cuda.is_available():
-        return False
-    props = torch.cuda.get_device_properties(0)
-    return props.major == 10

 
 # Skip decorators
 cute_dsl_available = pytest.mark.skipif(
     not is_cute_dsl_available(), reason="CuteDSL not available"
 )
 sm100_required = pytest.mark.skipif(
-    not is_sm100_family(),
-    reason="Requires SM100 family GPU (Blackwell: SM100, SM103, SM110)",
+    not is_sm100a_supported(torch.device("cuda")),
+    reason="Requires supported SM10x GPU",
 )

As per coding guidelines: tests/**/*.py: Skip test execution on unsupported GPU architectures using flashinfer.utils check functions (is_sm90a_supported(), is_sm100a_supported(), etc.) or API methods like api_name.is_compute_capability_supported(cc).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 56, Replace the
custom is_sm100_family() check with the repo-standard helper: import and use
flashinfer.utils.is_sm100a_supported() for the skip gate; specifically remove or
stop using the local is_sm100_family() and change the sm100_required
pytest.mark.skipif(...) to call flashinfer.utils.is_sm100a_supported(), ensuring
the test uses the common utility (and add the import if missing).
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (2)

138-138: ⚠️ Potential issue | 🔴 Critical

Size compact_topk_ids for routed pairs, not experts.

This buffer is still allocated as [state_E], but the micro pre-pass consumes it as a flattened routing-id buffer of length num_tokens * top_k. That will trip the assertion at Line 827 for small-expert configs where routed pairs exceed local experts.

🐛 Proposed fix
-    compact_topk_ids: torch.Tensor  # [state_E] int32, for micro kernel pre-pass
+    compact_topk_ids: torch.Tensor  # [max_rows * num_topk] int32, for micro kernel pre-pass
-        compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device),
+        compact_topk_ids=torch.empty(
+            max_rows * num_topk, dtype=torch.int32, device=device
+        ),

Also applies to: 185-185

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` at line 138,
The buffer compact_topk_ids is currently allocated with size [state_E]
(per-expert) but must be sized for routed pairs (flattened routing ids) consumed
by the micro pre-pass; change the allocation(s) of compact_topk_ids to length
num_tokens * top_k (or the explicit routed_pairs_count/num_routed_pairs variable
used in routing) and keep dtype int32, and update both places where
compact_topk_ids is created/allocated so the micro pre-pass (which reads it as a
flattened routing-id buffer) no longer overruns the array.

842-844: ⚠️ Potential issue | 🟡 Minor

Clamp micro MAC against actual M_tiles × N_tiles.

routed_rows * ceil(n / 128) still overcounts work when the routed rows fit inside one M tile, so the MAC cap can stay higher than the kernel has parallel work for.

🔧 Proposed fix
-        micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
+        m_tiles = max(1, (routed_rows + 127) // 128)
+        n_tiles = max(1, (n + 127) // 128)
+        micro_work_tiles = m_tiles * n_tiles
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
842 - 844, The micro_work_tiles calculation can overcount when routed_rows <
128; replace routed_rows * ceil(n/128) with tile counts so work is M_tiles ×
N_tiles: compute m_tiles = max(1, (routed_rows + 128 - 1) // 128) and n_tiles =
max(1, (n + 128 - 1) // 128) then set micro_work_tiles = m_tiles * n_tiles so
micro_mac (computed via tuned_mac/_lookup_mac_ladder, base_mac and
micro_work_tiles) is correctly clamped to the actual M_tiles × N_tiles parallel
work.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/cute_dsl/b12x_moe.py`:
- Around line 217-240: The current workspace preallocation uses
select_sm120_moe_backend based only on routed rows, which can choose "dynamic"
even when launch_sm120_moe would force static (dynamic is only valid when
num_local_experts == num_experts); update the allocation logic around
max_routed_rows so that before calling allocate_sm120_dynamic_workspace you also
require self.num_local_experts == self.num_experts (mirror launch_sm120_moe),
otherwise force allocate_sm120_static_workspace; reference
select_sm120_moe_backend, max_routed_rows, allocate_sm120_dynamic_workspace,
allocate_sm120_static_workspace, launch_sm120_moe and the use_cuda_graph/backend
inference behavior to ensure the preallocated workspace cannot lock in an
invalid dynamic backend.
- Around line 64-67: The APIs currently accept arbitrary output_dtype but the
SM12x kernels in moe_dispatch.py are hardcoded to cutlass.BFloat16, so add an
explicit runtime guard: in each function or constructor that accepts the
parameter named output_dtype (e.g., the b12x_moe signature and the other
occurrences of output_dtype in this file), assert or raise a clear ValueError
unless output_dtype is torch.bfloat16; include a short error message referencing
launch_sm120_moe/moe_dispatch hardcoding so callers know why only BF16 is
allowed. Ensure the same check is applied to all other places where output_dtype
is accepted in this module.
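The guard requested in that comment can be sketched as below. This is a pure-Python stand-in using dtype names as strings; the real check would compare against `torch.bfloat16` in `b12x_moe.py` before dispatching.

```python
def require_bf16_output(output_dtype: str) -> None:
    # The SM12x kernels are compiled with a BF16 epilogue only (per the
    # review, launch_sm120_moe / moe_dispatch.py hardcode BF16), so fail
    # fast with a clear message instead of a confusing kernel error.
    if output_dtype != "bfloat16":
        raise ValueError(
            "b12x_fused_moe only supports bfloat16 output "
            f"(SM12x kernels hardcode BF16); got {output_dtype!r}"
        )

require_bf16_output("bfloat16")  # accepted: no exception
```

Any other dtype, e.g. `"float16"`, raises `ValueError` at the API boundary rather than deep inside kernel compilation.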


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6076caa1-507c-4917-b9bf-cd104a8b7d8c

📥 Commits

Reviewing files that changed from the base of the PR and between 3481279 and 85511bc.

📒 Files selected for processing (4)
  • flashinfer/fused_moe/cute_dsl/b12x_moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment thread flashinfer/fused_moe/cute_dsl/b12x_moe.py
Comment thread flashinfer/fused_moe/cute_dsl/b12x_moe.py Outdated
@bkryu
Collaborator Author

bkryu commented Apr 16, 2026

/bot stop

@bkryu
Collaborator Author

bkryu commented Apr 16, 2026

/bot run

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48713073 has been cancelled.

@flashinfer-bot
Collaborator

GitLab MR !556 has been updated with latest changes, and the CI pipeline #48714380 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Apr 16, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48714380 has been cancelled.

@bkryu
Collaborator Author

bkryu commented Apr 16, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !556 has been updated with latest changes, and the CI pipeline #48739189 is currently running. I'll report back once the pipeline job completes.

@@ -0,0 +1,1334 @@
"""
Copyright (c) 2025 by FlashInfer team.
Collaborator

2026?

@aleozlx
Collaborator

aleozlx commented Apr 18, 2026

@flashinfer-bot run


Labels

op: moe run-ci v0.6.9 release blocker label for 0.6.9


4 participants