feat: Add b12x CuTe DSL fused MoE for SM120 #3066

Merged
bkryu merged 12 commits into flashinfer-ai:main from bkryu:b12x_cute_dsl_moe_v4
Apr 15, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Apr 14, 2026

📌 Description

Summary

  • Extend cute_dsl_fused_moe_nvfp4 and CuteDslMoEWrapper to support SM120/SM121 GPUs
  • Port the b12x fused MoE kernels (static + dynamic) using CuTe DSL
  • Add workspace cache for the functional API to avoid per-call allocation overhead

Architecture

The SM120 MoE kernel fuses the entire MoE operation into a single persistent kernel launch:

  • Route+Pack → FC1 (gate+up) → SwiGLU → Quant → FC2 (down) → Scatter-add
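
A reference for the SwiGLU stage in this pipeline can be sketched in NumPy (the up-first/gate-second half ordering is an assumption consistent with the non-interleaved SM120 weight layout noted under Key differences; the real kernel operates on FP4/bf16 tiles in shared memory):

```python
import numpy as np

def swiglu_ref(fc1_out: np.ndarray) -> np.ndarray:
    """Reference SwiGLU: split the fc1 output into up/gate halves and
    compute silu(gate) * up. Half ordering (up first) mirrors the
    non-interleaved SM120 layout."""
    up, gate = np.split(fc1_out, 2, axis=-1)
    silu_gate = gate / (1.0 + np.exp(-gate))  # silu(x) = x * sigmoid(x)
    return up * silu_gate
```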

Two kernel variants are auto-selected by routed pair count:

  • Static (routed_rows ≤ 640): Resident-grid barrier between route/pack and compute. Optimized for decode.
  • Dynamic (routed_rows > 640): Lock-free task queue with producer/consumer overlap. Optimized for prefill.
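
The threshold dispatch above can be sketched as follows (a simplified stand-in for the real select_sm120_moe_backend in moe_dispatch.py; the routed-pair count is assumed to be num_tokens * top_k):

```python
_STATIC_MAX_ROUTED_ROWS = 640  # crossover point quoted above

def select_backend(num_tokens: int, top_k: int) -> str:
    # Each token is routed to top_k experts, so the routed pair count
    # (routed_rows) grows with both batch size and top_k.
    routed_rows = num_tokens * top_k
    return "static" if routed_rows <= _STATIC_MAX_ROUTED_ROWS else "dynamic"
```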

Key differences from the SM100 CuTe DSL path:

  • Takes bf16 input (kernel quantizes internally) — requires x_bf16 parameter
  • Non-interleaved gate/up weights (up first, gate second)

API:

Both cute_dsl_fused_moe_nvfp4 and CuteDslMoEWrapper.run() gain an x_bf16 parameter:

# Functional API                                                                           
cute_dsl_fused_moe_nvfp4(x=fp4_quantized, x_sf=..., ..., x_bf16=bf16_tensor)

# Wrapper API (recommended for production — pre-allocated workspace)                                                              
moe = CuteDslMoEWrapper(num_experts=256, top_k=8, ..., use_cuda_graph=True)
moe.run(x=fp4_quantized, x_sf=..., ..., x_bf16=bf16_tensor)  

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added Blackwell (SM120/SM121) support for fused Mixture-of-Experts with optimized static and dynamic kernels, FP4 quantization/packing helpers, and a BF16 execution path for Blackwell devices.
  • New Features (CLI)

    • New CLI flag to toggle a functional execution mode for Blackwell runs.
  • Tests

    • Updated tests for Blackwell compatibility, weight/layout handling, and functional-vs-wrapper consistency.
  • Chores

    • Benchmarks and runners updated to exercise the new Blackwell execution paths.

@coderabbitai
Contributor

coderabbitai bot commented Apr 14, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds SM120/SM121 (Blackwell) CuTe-DSL fused MoE support: new SM12x static/dynamic kernels, dispatch/workspace/weight-view management, FP4 quantization helpers, and updates to wrapper/functional APIs, tests, and benchmarks to enable SM12x runtime paths.

Changes

Cohort / File(s) / Summary:

  • Benchmark Registry & Routines (benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/routines/moe.py): Marked cute-dsl supported for compute capabilities 12.0/12.1; added --use_functional_api toggle in CLI and adapted benchmark runners to pass x_bf16 and select functional vs wrapper execution.
  • CuTe-DSL FP4 Helpers (flashinfer/cute_dsl/fp4_common.py): Added PTX user-op wrappers (global/shared stores, atomics, FP8→f32) and multiple @cute.jit helpers for FP4 quantize/pack, reductions, SiLU fusion, and scatter/atomic reductions.
  • SM12x Package Init (flashinfer/fused_moe/cute_dsl/blackwell_geforce/__init__.py): New package initializer exporting SM120/SM121 MoE kernel classes, workspace allocators, and launch entrypoints.
  • SM12x Dispatch & Workspace (flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py): New dispatch layer: workspace dataclasses/allocators, weight-view preparation and caching, compilation caches for static/dynamic kernels, backend selection, and unified launch_sm120_moe entry.
  • SM12x Kernels (flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_static_kernel.py, flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py): Added MoEStaticKernel (resident-grid static-scheduled fused flow) and MoEDynamicKernel (queue-driven dynamic routed flow) implementing device routing/packing, FP4 quantization, FC compute, and scatter logic.
  • Core Wrapper & Functional API (flashinfer/fused_moe/cute_dsl/fused_moe.py): Extended compute-capability gating to include 120/121; added optional x_bf16 parameter, SM12x workspace/weight-view caching, and direct dispatch to launch_sm120_moe while preserving SM100/SM103 paths.
  • Tests (tests/moe/test_cute_dsl_fused_moe.py): Added is_sm120_family/CUDA-13 checks, adjusted weight preparation for SM12x (GEMM1 packing), and updated tests to supply x_bf16 and new gating markers.
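
The new CLI toggle can be sketched with a minimal parser (the flag name comes from the change summary above; the surrounding parser setup is hypothetical):

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Hypothetical minimal parser; the real flag lives in benchmarks/routines/moe.py
    p = argparse.ArgumentParser()
    p.add_argument(
        "--use_functional_api",
        action="store_true",
        help="Time cute_dsl_fused_moe_nvfp4 instead of CuteDslMoEWrapper.run",
    )
    return p
```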

Sequence Diagram

sequenceDiagram
    participant Client
    participant Wrapper as CuteDslMoEWrapper
    participant Dispatch as SM12x_Dispatch
    participant Static as SM120_StaticKernel
    participant Dynamic as SM120_DynamicKernel
    participant CUDA as CUDA_Runtime

    Client->>Wrapper: run(x, topk_ids, topk_weights, x_bf16)
    alt SM120/SM121 detected (CUDA≥13)
        Wrapper->>Dispatch: launch_sm120_moe(...)
        Dispatch->>Dispatch: select_sm120_moe_backend()
        alt static backend
            Dispatch->>Static: compile/cache & launch(...)
            Static->>CUDA: execute kernel (route/pack -> quantize -> FC1/SiLU/FC2 -> scatter)
        else dynamic backend
            Dispatch->>Dynamic: compile/cache & launch(...)
            Dynamic->>CUDA: execute kernel (histogram -> producer pack -> consumer compute -> scatter)
        end
        CUDA->>Client: moe_output
    else
        Wrapper->>Wrapper: fallback SM100/SM103 pipeline (autotuned)
        CUDA->>Client: moe_output
    end

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • yzh119
  • sricketts
  • samuellees
  • nv-yunzheq
  • cyx-6
  • aleozlx
  • kahyunnam

Poem

🐰 I hopped through Blackwell's gates at dawn,
kernels hum, FP4 bytes tucked in a yawn,
producers write, consumers weave the song,
warps dance, scales align, tasks march along,
MoE sprouts bloom — the rabbit grins and’s gone.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 52.04%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title clearly and concisely summarizes the main change: adding SM120 support for CuTe DSL fused MoE kernels (b12x), which aligns with the extensive kernel additions and API extensions in the changeset.
  • Description check (✅ Passed): The PR description follows the template structure with comprehensive Summary, Architecture, API sections and a completed checklist.

Comment @coderabbitai help to get the list of available commands and usage tips.

@bkryu bkryu added the run-ci label Apr 14, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Blackwell (SM120/SM121) architectures in the CuTe DSL MoE kernels, including both static and dynamic dispatch backends. The changes add new kernel implementations, dispatch logic, and updates to the wrapper and functional APIs to handle the fused quantization path required for SM120. My feedback highlights several areas for improvement:

  • Replace the hardcoded 32 constant in kernel stage calculations with shape-derived logic.
  • Refactor the duplicated code in the benchmark wrapper.
  • Move the function-level import in the dispatch layer to module level.
  • Replace the magic-number offsets of the shared-memory control structures with named struct fields for better maintainability.

Comment on lines +358 to +359
while self.ab_stage > 1 and 32 % self.ab_stage != 0:
    self.ab_stage -= 1
Contributor


high

The value 32 is hardcoded here. The comment indicates it's derived from K/tile_K. Hardcoding this value makes the kernel brittle and dependent on a specific problem size (e.g., K=4096). This should be computed dynamically based on the input tensor shapes to ensure correctness for different K values. If K is not available at this stage, it should be passed in or the logic should be moved to a place where it is available.

Collaborator Author


Good catch. The hardcoded 32 assumed K=4096 with tile_k=128. Fixed by passing hidden_size from a_input.shape[1] into _setup_attributes and computing k_tile_cnt = hidden_size // tile_shape_mnk[2] dynamically. Applied to both static and dynamic kernels (two locations each) in the next commit.
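
The fix described above can be sketched as a standalone helper (clamp_ab_stage is hypothetical; in the kernels the loop runs inline in _setup_attributes):

```python
def clamp_ab_stage(ab_stage: int, hidden_size: int, tile_k: int) -> int:
    # Derive the K-tile count from the input shape instead of hardcoding 32
    k_tile_cnt = hidden_size // tile_k
    # Round down to the nearest stage count that divides k_tile_cnt so the
    # pipeline phases stay aligned across the K loop.
    while ab_stage > 1 and k_tile_cnt % ab_stage != 0:
        ab_stage -= 1
    return ab_stage
```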

Comment on lines +528 to +529
while self.ab_stage > 1 and 32 % self.ab_stage != 0:
    self.ab_stage -= 1
Contributor


high

The value 32 is hardcoded here. The comment indicates it's derived from K/tile_K. Hardcoding this value makes the kernel brittle and dependent on a specific problem size (e.g., K=4096). This should be computed dynamically based on the input tensor shapes to ensure correctness for different K values. If K is not available at this stage, it should be passed in or the logic should be moved to a place where it is available.

Collaborator Author


Good catch. The hardcoded 32 assumed K=4096 with tile_k=128. Fixed by passing hidden_size from a_input.shape[1] into _setup_attributes and computing k_tile_cnt = hidden_size // tile_shape_mnk[2] dynamically. Applied to both static and dynamic kernels (two locations each) in the next commit.

Comment thread benchmarks/routines/moe.py Outdated
Comment on lines +1362 to +1466
if use_functional:
    from flashinfer import cute_dsl_fused_moe_nvfp4

    if args.verbose >= 1:
        print(
            "[INFO] Using functional API (cute_dsl_fused_moe_nvfp4) with workspace cache"
        )

    # Pre-allocate output buffer to avoid per-call allocation
    moe_output = torch.empty(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
    )

    # Warmup call to populate workspace cache before timed region
    cute_dsl_fused_moe_nvfp4(
        x=tensors["x"],
        x_sf=tensors["x_sf"],
        token_selected_experts=tensors["token_selected_experts"],
        token_final_scales=tensors["token_final_scales"],
        w1_weight=tensors["w1_weight"],
        w1_weight_sf=tensors["w1_weight_sf"],
        w1_alpha=tensors["w1_alpha"],
        fc2_input_scale=tensors["fc2_input_scale"],
        w2_weight=tensors["w2_weight"],
        w2_weight_sf=tensors["w2_weight_sf"],
        w2_alpha=tensors["w2_alpha"],
        num_experts=num_experts,
        top_k=top_k,
        num_local_experts=local_num_experts,
        x_bf16=tensors["x_bf16"],
        moe_output=moe_output,
    )

    def run_cute_dsl_moe(
        x,
        x_sf,
        token_selected_experts,
        token_final_scales,
        w1_weight,
        w1_weight_sf,
        w1_alpha,
        fc2_input_scale,
        w2_weight,
        w2_weight_sf,
        w2_alpha,
        x_bf16=None,
    ):
        return cute_dsl_fused_moe_nvfp4(
            x=x,
            x_sf=x_sf,
            token_selected_experts=token_selected_experts,
            token_final_scales=token_final_scales,
            w1_weight=w1_weight,
            w1_weight_sf=w1_weight_sf,
            w1_alpha=w1_alpha,
            fc2_input_scale=fc2_input_scale,
            w2_weight=w2_weight,
            w2_weight_sf=w2_weight_sf,
            w2_alpha=w2_alpha,
            num_experts=num_experts,
            top_k=top_k,
            num_local_experts=local_num_experts,
            x_bf16=x_bf16,
            moe_output=moe_output,
        )
else:
    moe = CuteDslMoEWrapper(
        num_experts=num_experts,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        use_cuda_graph=is_cuda_graph_compatible,
        max_num_tokens=num_tokens,
        num_local_experts=local_num_experts,
        local_expert_offset=local_expert_offset,
    )

    def run_cute_dsl_moe(
        x,
        x_sf,
        token_selected_experts,
        token_final_scales,
        w1_weight,
        w1_weight_sf,
        w1_alpha,
        fc2_input_scale,
        w2_weight,
        w2_weight_sf,
        w2_alpha,
        x_bf16=None,
    ):
        return moe.run(
            x=x,
            x_sf=x_sf,
            token_selected_experts=token_selected_experts,
            token_final_scales=token_final_scales,
            w1_weight=w1_weight,
            w1_weight_sf=w1_weight_sf,
            w1_alpha=w1_alpha,
            fc2_input_scale=fc2_input_scale,
            w2_weight=w2_weight,
            w2_weight_sf=w2_weight_sf,
            w2_alpha=w2_alpha,
            x_bf16=x_bf16,
        )
Contributor


medium

There's significant code duplication in this if/else block. The run_cute_dsl_moe function is defined twice with an identical long signature. This can be refactored to improve maintainability by defining the function only once. You can use a runner variable to hold either moe.run or a functools.partial or lambda wrapping cute_dsl_fused_moe_nvfp4, and then call this runner inside a single run_cute_dsl_moe implementation. This would make the code more DRY and easier to maintain.

Collaborator Author


Refactored in the next commit to define run_cute_dsl_moe once. The if/else now sets a runner variable — either a functools.partial wrapping the functional API or moe.run — and a single function definition delegates to it.
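
The runner pattern can be illustrated with stand-in callables (functional_api and Wrapper below are placeholders for cute_dsl_fused_moe_nvfp4 and CuteDslMoEWrapper, not the real APIs):

```python
import functools

def functional_api(x, moe_output=None):
    # Placeholder for cute_dsl_fused_moe_nvfp4
    return ("functional", x, moe_output)

class Wrapper:
    # Placeholder for CuteDslMoEWrapper
    def run(self, x):
        return ("wrapper", x)

def make_runner(use_functional: bool):
    if use_functional:
        # Bind the static arguments once, as in the benchmark refactor
        return functools.partial(functional_api, moe_output="prealloc")
    return Wrapper().run

def run_cute_dsl_moe(runner, x):
    # Single definition that delegates to whichever runner was selected
    return runner(x)
```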

# We need the ORIGINAL physical storage, not .contiguous() of the view
# (which would write in permuted logical order).
# convert_sf_from_mma_layout reverses the permutation back to 2D swizzled.
from flashinfer.cute_dsl.utils import convert_sf_from_mma_layout
Contributor


medium

Placing from imports inside functions is generally discouraged as it can hide dependencies and lead to unexpected behavior. It's better to have all imports at the top of the file for clarity and to ensure all dependencies are resolved at module load time. If this is done to avoid a circular dependency, consider refactoring the modules to break the cycle.

Collaborator Author


Agreed — no circular dependency here. Moved convert_sf_from_mma_layout to the top-level import block alongside the other flashinfer.cute_dsl.utils imports in the next commit.


@cute.struct
class Storage:
    ctrl: cute.struct.MemRange[cutlass.Int32, 8]
Contributor


medium

The shared memory ctrl buffer is defined as a MemRange of 8 Int32s and accessed using magic number offsets (e.g., 0, 4, 8, ..., 28). This makes the code hard to read and maintain. It would be much clearer to define ctrl as a cute.struct with named fields. This would allow accessing fields by name (e.g., storage.ctrl.batch_base) instead of using hardcoded offsets, improving readability and reducing the risk of errors.

Example:

@cute.struct
class CtrlBlock:
    has_task: cutlass.Int32
    done: cutlass.Int32
    expert_idx: cutlass.Int32
    m_tile_idx: cutlass.Int32
    slice_begin_idx: cutlass.Int32
    slice_count_val: cutlass.Int32
    valid_rows_val: cutlass.Int32
    batch_base: cutlass.Int32

# In Storage:
ctrl: CtrlBlock

Collaborator Author


Good suggestion. Added a documenting comment in the next commit that maps byte offsets to field names. Deferring the full struct refactor to a follow-up — the ctrl buffer is accessed via raw PTX ld.shared/st.shared intrinsics using byte addresses, and restructuring to named fields would require validating that CuTe DSL struct field access works correctly with these raw PTX operations.
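
For reference, the offset-to-field mapping in such a documenting comment might look like this (field names follow the suggested CtrlBlock above; the contiguous 4-byte Int32 slots are an assumption about the ctrl layout):

```python
# Byte offsets into the 8 x Int32 ctrl MemRange, as accessed by the raw
# PTX ld.shared/st.shared intrinsics (hypothetical documenting map).
CTRL_FIELD_OFFSETS = {
    "has_task": 0,
    "done": 4,
    "expert_idx": 8,
    "m_tile_idx": 12,
    "slice_begin_idx": 16,
    "slice_count_val": 20,
    "valid_rows_val": 24,
    "batch_base": 28,
}
```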

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

🧹 Nitpick comments (1)
tests/moe/test_cute_dsl_fused_moe.py (1)

38-54: Use flashinfer.utils for architecture gating.

These helpers reimplement compute-capability checks with raw torch.cuda queries. Please switch the skips to flashinfer.utils.get_compute_capability() / is_sm100a_supported() so the test gating stays aligned with the rest of the suite.

As per coding guidelines, "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 54, Replace the local
torch.cuda-based helpers is_sm100_family and is_sm120_family with calls to the
shared flashinfer.utils helpers: use flashinfer.utils.is_sm100a_supported() (or
is_sm90a_supported() where appropriate) to gate SM100-family tests and
flashinfer.utils.get_compute_capability() only if you need the numeric
major/minor; remove direct torch.cuda.* calls in is_sm100_family and
is_sm120_family and instead call the corresponding flashinfer.utils functions so
the test gating aligns with the rest of the suite.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1376-1393: The functional benchmark calls to
cute_dsl_fused_moe_nvfp4 (and the second similar call around lines 1409-1426)
fail to forward the local_expert_offset, so benchmarks with a non-zero expert
partition offset time the wrong shard; fix by adding the local_expert_offset
argument to those cute_dsl_fused_moe_nvfp4 calls (e.g.,
local_expert_offset=local_expert_offset) so the functional path uses the same
offset as CuteDslMoEWrapper.

In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py`:
- Around line 1287-1321: The dynamic backend can crash when num_experts !=
num_local_experts because dynamic kernels index per-local buffers (row_counts,
expert_write_rows) with topk_ids while compiled weights use global num_experts;
update the dispatch logic in select_sm120_moe_backend/where backend is handled
so that before calling launch_sm120_dynamic_moe you verify num_experts ==
num_local_experts (or equivalent global-to-local mapping exists); if they
differ, route the call to the static path (e.g., call launch_sm120_static_moe or
the existing static dispatch) using the same workspace resolution
(_get_cached_workspace) to avoid out-of-bounds indexing by topk_ids and
mismatched expert dims.
- Around line 1197-1206: The cache check can return a workspace that appears
large enough by raw routed_rows but lacks required 128-row tile padding used by
Sm120DynamicMoEWorkspace; change the reuse condition in the cache lookup (the
block that builds cache_key and reads _WORKSPACE_CACHE) to compare against the
tiled/padded requirement instead of routed_rows directly: compute required_tiles
= ceil(max(1, routed_rows) / 128) and required_rows = required_tiles * 128 (or
otherwise derive the per-physical-tile capacity the Sm120DynamicMoEWorkspace
uses) and only reuse cached if cached.max_rows >= required_rows; update any
related uses of max(1, routed_rows) to use this padded required_rows so
Sm120DynamicMoEWorkspace allocations honor physical_tiles * 128 semantics.
- Around line 185-248: The code fails for non-128-aligned n because SM120
kernels assume the gate/up split lands on a 128-column tile; before
creating/permute/concatenating w13 (i.e. before using w1_fp4/w2_fp4 and
computing w13 = w1_fp4.permute(...)), assert or enforce alignment of n against
_LEVEL_TILE_N: if n % _LEVEL_TILE_N != 0 either pad/repack the weight tensors to
the next multiple of _LEVEL_TILE_N (adjusting associated blockscale views
w1_blockscale/w2_blockscale and their convert_sf_from_mma_layout calls) or raise
a clear exception indicating non-128-aligned n; update any keys using n (the key
for _WEIGHT_CACHE) so cached entries reflect padded size.

In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py`:
- Around line 343-380: The current logic reduces self.ab_stage to divide the K
tiles but doesn't re-check the actual shared-memory footprint including the
extra staged sB_up/sSFB_up pair allocated in Storage, so a theoretically-valid
ab_stage can still overflow SMEM at runtime; update moe_dynamic_kernel.py to
mirror MoEStaticKernel._setup_attributes by computing the full SMEM usage after
calling self._dense_cls._make_smem_layouts (using the returned
a_smem_layout_staged, b_smem_layout_staged, sfa_smem_layout_staged,
sfb_smem_layout_staged, epi_smem_layout_staged and Storage's staged buffers) and
then loop-decrement self.ab_stage until both (a) it divides the K tile count and
(b) the total SMEM usage fits within self.smem_capacity, recalculating layouts
via _make_smem_layouts each time you change self.ab_stage; this ensures the
dynamic kernel's final ab_stage is clamped to the real shared-memory footprint
before launch.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 206-220: The SM120 dispatch call to launch_sm120_moe is missing
propagation of local_expert_offset so global expert IDs (token_selected_experts)
are not mapped to the local expert index space; update the launch_sm120_moe call
sites (the current call around the a=x block and the similar sites around lines
noted) to pass local_expert_offset (or compute and pass
token_selected_experts_local = token_selected_experts - local_expert_offset) and
ensure the kernel or pre-dispatch logic uses token_selected_experts_local when
indexing local weights (w1_weight/w2_weight) and when setting
num_local_experts/num_experts_local; in short, either add a local_expert_offset
parameter to launch_sm120_moe and apply it inside the implementation, or
subtract local_expert_offset from token_selected_experts right before dispatch
at each call site so sharded expert layouts get the correct local IDs.
- Around line 679-690: The cached _sm120_weight_views created by
_get_sm120_weight_views is reused across run() calls but never invalidated,
causing stale views if new w1_* or w2_* tensors (or their block scales/alphas)
are passed; update run() to detect changes and recompute the cache: either clear
self._sm120_weight_views (set to None) at the start of run() or compare incoming
tensor identities/metadata (w1_weight, w1_weight_sf, w1_alpha, w2_weight,
w2_weight_sf, w2_alpha, intermediate_size, hidden_size) against stored
references and call _get_sm120_weight_views again when any differ so the cached
views always match the current inputs.

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 65-67: The mark sm100_only currently excludes SM120 by using
is_sm120_family() in its condition which skips all wrapper/tactics coverage on
SM120; change the test markers so wrapper/accuracy smoke tests run on SM120
while only tactic-specific cases remain gated to SM100. Concretely, remove the
is_sm120_family() check from the sm100_only definition (so it only skips when
not is_sm100_family()), update all occurrences (sm100_only at the shown
locations) accordingly, and introduce a separate marker (e.g., sm100_tactics =
pytest.mark.skipif(not is_sm100_family(), reason="tactics require SM100")) to
apply only to tactic-specific tests.
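
The padded reuse check called for in the workspace-cache comment above can be sketched as follows (required_workspace_rows and can_reuse are hypothetical helpers; 128 is the per-physical-tile row count cited in the comment):

```python
import math

_TILE_ROWS = 128  # Sm120DynamicMoEWorkspace sizes buffers in 128-row tiles

def required_workspace_rows(routed_rows: int) -> int:
    # Round up to whole physical tiles before comparing against the cache
    tiles = math.ceil(max(1, routed_rows) / _TILE_ROWS)
    return tiles * _TILE_ROWS

def can_reuse(cached_max_rows: int, routed_rows: int) -> bool:
    # Compare against the tile-padded requirement, not raw routed_rows
    return cached_max_rows >= required_workspace_rows(routed_rows)
```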

---

Nitpick comments:
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 38-54: Replace the local torch.cuda-based helpers is_sm100_family
and is_sm120_family with calls to the shared flashinfer.utils helpers: use
flashinfer.utils.is_sm100a_supported() (or is_sm90a_supported() where
appropriate) to gate SM100-family tests and
flashinfer.utils.get_compute_capability() only if you need the numeric
major/minor; remove direct torch.cuda.* calls in is_sm100_family and
is_sm120_family and instead call the corresponding flashinfer.utils functions so
the test gating aligns with the rest of the suite.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0865631a-9c33-4fc7-b0dd-ac0b65efeb94

📥 Commits

Reviewing files that changed from the base of the PR and between 8c93f92 and 69ff093.

📒 Files selected for processing (9)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/moe.py
  • flashinfer/cute_dsl/fp4_common.py
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/__init__.py
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_static_kernel.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment on lines +343 to +380
self.ab_stage, self.epi_stage = self._dense_cls._compute_stages(
self.tile_shape_mnk,
self.a_dtype,
self.b_dtype,
self.sf_dtype,
sfa_smem,
sfb_smem,
self.epi_tile,
cutlass.BFloat16,
self.smem_capacity,
self.occupancy,
)
# ab_stage must divide k_tile_cnt (K/tile_K = 4096/128 = 32) evenly.
# _compute_stages returns the max that fits in smem (e.g. 3), but
# 32%3!=0 causes pipeline phase mismatch. Round down to nearest divisor.
while self.ab_stage > 1 and 32 % self.ab_stage != 0:
self.ab_stage -= 1
self.epi_stage = 1
(
self.a_smem_layout_staged,
self.b_smem_layout_staged,
self.sfa_smem_layout_staged,
self.sfb_smem_layout_staged,
self.epi_smem_layout_staged,
) = self._dense_cls._make_smem_layouts(
self.tile_shape_mnk,
self.epi_tile,
self.a_dtype,
self.a_layout,
self.b_dtype,
self.b_layout,
self.ab_stage,
cutlass.BFloat16,
self.c_layout,
self.epi_stage,
self.sf_vec_size,
self.tiled_mma,
)
Contributor

@coderabbitai coderabbitai bot Apr 14, 2026


⚠️ Potential issue | 🟠 Major

Clamp ab_stage against the real shared-memory footprint.

_compute_stages() only sizes the dense-kernel layouts, but this kernel later allocates an extra staged sB_up/sSFB_up pair in Storage. Unlike MoEStaticKernel._setup_attributes(), nothing rechecks the final SMEM usage before launch, so a stage count that looks legal here can still overflow shared memory and fail the dynamic kernel at runtime. Please mirror the static kernel’s post-check and decrement loop here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py` around
lines 343 - 380, The current logic reduces self.ab_stage to divide the K tiles
but doesn't re-check the actual shared-memory footprint including the extra
staged sB_up/sSFB_up pair allocated in Storage, so a theoretically-valid
ab_stage can still overflow SMEM at runtime; update moe_dynamic_kernel.py to
mirror MoEStaticKernel._setup_attributes by computing the full SMEM usage after
calling self._dense_cls._make_smem_layouts (using the returned
a_smem_layout_staged, b_smem_layout_staged, sfa_smem_layout_staged,
sfb_smem_layout_staged, epi_smem_layout_staged and Storage's staged buffers) and
then loop-decrement self.ab_stage until both (a) it divides the K tile count and
(b) the total SMEM usage fits within self.smem_capacity, recalculating layouts
via _make_smem_layouts each time you change self.ab_stage; this ensures the
dynamic kernel's final ab_stage is clamped to the real shared-memory footprint
before launch.
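
The suggested post-check could take roughly this shape (smem_bytes_for is a hypothetical cost model standing in for recomputing _make_smem_layouts plus the staged Storage buffers):

```python
def fit_ab_stage(ab_stage, k_tile_cnt, smem_capacity, smem_bytes_for):
    # Decrement until the stage count both divides the K tiles and the
    # modeled shared-memory footprint fits the per-SM budget.
    while ab_stage > 1 and (
        k_tile_cnt % ab_stage != 0 or smem_bytes_for(ab_stage) > smem_capacity
    ):
        ab_stage -= 1
    return ab_stage
```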

Collaborator Author


The _compute_stages output is already conservatively clamped (max(1, min(ab_stage, 4))) and then further reduced by the k_tile_cnt divisibility check. In practice this yields ab_stage=1 or 2 for all tested configurations, which fits within SM120's 232KB SMEM budget even with the extra sB_up/sSFB_up buffers. Adding the full SMEM post-check (as in the static kernel) would require implementing a separate _shared_storage_size_bytes() for the dynamic kernel's different struct layout — we'll add this in a follow-up if we see runtime SMEM overflow failures.

Contributor


@bkryu, understood — the conservative two-stage cap (max(1, min(ab_stage, 4)) → divisibility reduction) makes a runtime overflow unlikely in current configurations. I'll note this design decision for future reviews and won't re-raise it here.

Would you like me to open a follow-up GitHub issue to track adding _shared_storage_size_bytes() for the dynamic kernel and the corresponding post-check loop, so it doesn't get lost?


✏️ Learnings added
Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3066
File: flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py:343-380
Timestamp: 2026-04-14T19:10:27.074Z
Learning: In `flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py` (flashinfer-ai/flashinfer PR `#3066`), `MoEDynamicKernel._setup_attributes()` intentionally omits the full SMEM post-check loop present in `MoEStaticKernel`. The `_compute_stages` output is already conservatively clamped (`max(1, min(ab_stage, 4))`) and further reduced by the k_tile_cnt divisibility check, yielding ab_stage=1 or 2 in all tested configurations — well within SM120's 232KB SMEM budget even with the extra staged sB_up/sSFB_up pair. A proper `_shared_storage_size_bytes()` for the dynamic kernel's different struct layout would be needed for a full post-check; the maintainer deferred this to a follow-up. Do not re-flag the missing post-check as a bug.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: TomerBN-Nvidia
Repo: flashinfer-ai/flashinfer PR: 3024
File: csrc/fused_moe/noAuxTcKernels.cu:351-369
Timestamp: 2026-04-12T12:18:22.194Z
Learning: In `csrc/fused_moe/noAuxTcKernels.cu` (flashinfer-ai/flashinfer PR `#3024`), the `routing_replay_out` validation in `NoAuxTc` intentionally does NOT check `replay.sizes()[0] >= num_tokens`. This is by design: with CUDA graphs, the buffer is pre-allocated at maximum batch size and reused across steps with varying `num_tokens`; the kernel only writes to indices `[0, num_tokens)` so a larger buffer is always safe. The same policy applies to `csrc/trtllm_fused_moe_kernel_launcher.cu` (documented at line ~1795). Do not flag the missing lower-bound dim0 check as a bug.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2904
File: flashinfer/quantization/kernels/mxfp8_quantize.py:114-116
Timestamp: 2026-03-27T20:33:11.994Z
Learning: In `flashinfer/quantization/kernels/mxfp8_quantize.py` (flashinfer-ai/flashinfer), `_compute_optimal_warps_for_k` must receive `sf_blocks_per_warp` as an explicit parameter (not use the global `SF_BLOCKS_PER_WARP=16` constant). The `MXFP8QuantizeSwizzledKernel` constructor calls it with `self._sf_blocks_per_warp`, which is `SF_BLOCKS_PER_WARP=16` when `use_2t_per_sf=True` and `SF_BLOCKS_PER_WARP_SMALL=8` when `use_2t_per_sf=False`. Using the wrong constant causes fractional `rows_per_block` (e.g., K=3072 4T/SF: 30 warps → 960 threads → 2.5 truncated to 2 → write race from excess threads overlapping the next block's first row). MXFP4 and NVFP4 are unaffected because they use 1 thread per SF block with no multi-thread variant.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2904
File: flashinfer/quantization/kernels/mxfp4_quantize.py:668-685
Timestamp: 2026-03-27T20:21:04.233Z
Learning: In `flashinfer/quantization/kernels/mxfp4_quantize.py` (flashinfer-ai/flashinfer), the swizzled MXFP4 layout requires K % 128 == 0, which guarantees that `num_sf_blocks_per_row = K // 32` is always a multiple of 4, so `padded_sf_cols == num_sf_blocks_per_row` — no SF column padding ever occurs. The final `scale_output.reshape(-1, num_sf_blocks_per_row)` at the end of `mxfp4_quantize_cute_dsl` is therefore correct for both linear and swizzled paths: it uses the logical SF column count (which downstream consumers expect), not a physically padded count. This matches the CUDA backend behavior (`sf.reshape((-1, input.shape[-1] // sf_vec_size))`). Do not flag this reshape as incorrect or suggest branching on `padded_sf_cols`.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2962
File: include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh:232-262
Timestamp: 2026-04-02T18:45:38.854Z
Learning: In `include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh` (flashinfer-ai/flashinfer PR `#2962`), the per-step `state_dst_slots` precompute has three mutually exclusive branches:
1. `dst_state_batch_indices` present → always write unless index == pad_slot_id (caller controls slots via pad_slot_id; no update_state gating needed).
2. `intermediate_states` present → always cache every step (no update_state gating needed).
3. Neither → only write at last step when params.update_state is true.
`intermediate_states_buffer` and `dst_state_batch_indices` are enforced mutually exclusive by a Python-side ValueError in `flashinfer/mamba/selective_state_update.py`. Do not flag the absence of `update_state` gating in branches 1 and 2 as a bug.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3007
File: tests/utils/test_norm.py:0-0
Timestamp: 2026-04-07T21:44:40.431Z
Learning: In `tests/utils/test_norm.py` (flashinfer-ai/flashinfer), when writing regression tests for large-stride (> INT32_MAX) tensor paths, using `torch.as_strided(small_buf, (M, H), (_INT64_STRIDE, 1))` on a small buffer is unsafe and will segfault because row 1 is at byte offset `2^31` beyond the allocation. The correct pattern is: allocate a flat buffer of at least `_INT64_STRIDE + H` elements, then create the strided view from it so every row is backed by real memory. For fused_add_rmsnorm tests, only the input tensor `x` needs to be non-contiguous (using the large flat buffer); the residual `r` can remain a normally-allocated contiguous tensor — one non-contiguous tensor is sufficient to trigger `is_contiguous() == False` in the kernel selection path, avoiding a second ~4 GB allocation.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2838
File: flashinfer/quantization/kernels/nvfp4_quantize.py:967-976
Timestamp: 2026-03-23T18:58:22.437Z
Learning: In `flashinfer/quantization/kernels/nvfp4_quantize.py` (flashinfer-ai/flashinfer), the TMA dispatch predicate `m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD` (i.e., floor(log2(M)) + floor(log2(K)) >= 25) is intentional. It is a deliberate approximation of the `M*K >= 2^25` threshold — not a bug. The maintainer acknowledged this and will add a clarifying comment in a follow-up commit. Do not flag this as incorrect or suggest replacing it with `m * k >= (1 << _TMA_LOG2_MK_THRESHOLD)`.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2904
File: flashinfer/quantization/kernels/mxfp8_quantize.py:384-385
Timestamp: 2026-03-27T20:51:45.564Z
Learning: In `flashinfer/quantization/kernels/mxfp8_quantize.py` (`MXFP8QuantizeSwizzledKernel`, small-K path), the padding-column zeroing in the swizzled small-K path requires a thread-stride loop, not a simple predicated write. Because `sf_col_idx = local_tidx // _threads_per_sf` is bounded by `[0, num_sf_blocks_per_row)`, a bare `if sf_col_idx >= num_sf_blocks_per_row` guard is unreachable. The correct pattern (matching MXFP4/NVFP4 swizzled kernels) is:
- Padding rows: loop starting at `sf_col_idx`, striding by `num_sf_blocks_per_row`, up to `padded_sf_cols`.
- Real rows: loop starting at `num_sf_blocks_per_row + sf_col_idx`, striding by `num_sf_blocks_per_row`, guarded by `const_expr(self.num_sf_blocks_per_row != self.padded_sf_cols)` so it is eliminated at compile time when `K/32` is a multiple of 4 (no column padding needed).

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2773
File: include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh:27-32
Timestamp: 2026-03-12T21:29:16.342Z
Learning: In `include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` (flashinfer-ai/flashinfer), the `static_assert` inside the `PHILOX_ROUNDS > 0` block that restricts stochastic rounding to fp16 state (`std::is_same_v<state_t, half>`) is intentionally kept in the CUDA header close to the implementation rather than being guarded by a pre-JIT Python-side runtime check. The maintainer prefers this colocation for easier auditability. Do not suggest moving or duplicating this constraint to the Python layer.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2865
File: include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh:343-366
Timestamp: 2026-03-23T21:04:23.630Z
Learning: FlashInfer Mamba SSU MTP kernels: When applying the z gate in epilogues, index z using its own strides (z_stride_batch/z_stride_mtp), not the output strides. Verified in include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh (role_epilogue) and include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh (inline epilogue).

Learnt from: depaulmillz
Repo: flashinfer-ai/flashinfer PR: 2738
File: csrc/group_gemm_nvfp4_groupwise_sm120.cu:101-102
Timestamp: 2026-03-11T16:24:19.500Z
Learning: In `flashinfer/gemm/gemm_base.py`, for group GEMM kernel launchers (e.g., `group_gemm_nvfp4_nt_groupwise`, `group_gemm_mxfp4_nt_groupwise`, `group_gemm_fp8_nt_groupwise`), both `int_workspace_buffer` and `float_workspace_buffer` are always allocated via `_get_cache_buf(..., a.device)`, guaranteeing they share the same device as input tensor `a`. In `csrc/group_gemm_nvfp4_groupwise_sm120.cu` and similar launchers, sourcing the `ffi::CUDADeviceGuard` from `float_workspace_buffer.device()` and the stream from `A.device()` is safe and intentional — do not flag this as a device inconsistency.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 2777
File: flashinfer/norm/kernels/rmsnorm.py:101-103
Timestamp: 2026-03-13T01:09:32.828Z
Learning: In `flashinfer/norm/kernels/rmsnorm.py` and `flashinfer/norm/kernels/fused_add_rmsnorm.py`, kernel tuning parameters such as `use_async_copy` and shared memory estimates are computed using `torch.cuda.get_device_properties(torch.cuda.current_device())`, while `sm_version` is sourced from `input.device`. This is intentional and acceptable because flashinfer targets homogeneous multi-GPU systems where all devices share the same SM version and SMEM capacity. Do not flag this as a device inconsistency.

Learnt from: bkryu
Repo: flashinfer-ai/flashinfer PR: 3026
File: include/flashinfer/gemm/fp4_gemm_template_sm120.h:267-270
Timestamp: 2026-04-09T21:51:00.268Z
Learning: In flashinfer-ai/flashinfer, `include/flashinfer/gemm/fp4_gemm_template_sm120.h` is gated by `#define FLASHINFER_ENABLE_SM120` and is only included from `fp4_gemm_cutlass_template_sm120.h`, which is compiled exclusively for SM120/SM121 targets. Adding a runtime `Sm12xOnly` architecture guard inside this file is redundant — there is no code path that instantiates these kernels on non-SM12x hardware. Do not suggest adding such guards to this file.

Learnt from: kahyunnam
Repo: flashinfer-ai/flashinfer PR: 2965
File: tests/norm/test_fused_rmsnorm_silu.py:138-141
Timestamp: 2026-04-03T21:06:16.453Z
Learning: In `tests/norm/test_fused_rmsnorm_silu.py` (flashinfer-ai/flashinfer PR `#2965`), the full `ALL_LUT_SHAPES` test matrix (8 hidden sizes × 5 token counts, up to 399,360 tokens) across bf16, FP8, and NVFP4 is intentionally kept as the default CI parametrization. The maintainer confirmed the tests are fast and do not need to be split into a smoke subset vs. a slow marker. Do not flag this test matrix as too large for CI.

Learnt from: TomerBN-Nvidia
Repo: flashinfer-ai/flashinfer PR: 2685
File: flashinfer/fused_moe/core.py:1829-1829
Timestamp: 2026-03-05T08:44:44.454Z
Learning: When reviewing files that define fake ops decorated with register_fake_op (e.g., in flashinfer/fused_moe/*), ensure the function signatures exactly mirror the real op they stand in for. Unused parameters triggering ARG001 are intentional in these fake ops and should not be flagged or removed. Do not suggest adding no-op parameters like _ = param. This guideline applies to all such fake-op implementations within the flashinfer/fused_moe directory (and similar patterns elsewhere if applicable).


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (1)
tests/moe/test_cute_dsl_fused_moe.py (1)

38-67: Split the “supported family” check from the true SM100-only check.

is_sm100_family() now returns True for SM120 as well, so every caller has to remember that the name no longer matches the predicate. A broader helper name here, plus a real SM100-only helper, would make these skip markers much harder to misuse later.
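A minimal sketch of the suggested split, with the capability tuple passed in explicitly for clarity (the real helpers would read it from `torch.cuda.get_device_capability`; the function names follow the suggestion above and are otherwise hypothetical):

```python
def is_sm100_only(cc: tuple[int, int]) -> bool:
    """True only for SM100/SM103 devices."""
    return cc in ((10, 0), (10, 3))

def is_sm120_family(cc: tuple[int, int]) -> bool:
    """True for SM120/SM121 devices."""
    return cc in ((12, 0), (12, 1))

def is_supported_sm_family(cc: tuple[int, int]) -> bool:
    """Broad predicate: every compute capability the CuTe DSL MoE path supports."""
    return is_sm100_only(cc) or is_sm120_family(cc)
```

The skip markers would then use `is_supported_sm_family` for "any supported GPU" and `is_sm100_only` for tests that genuinely require SM100, so each name matches its predicate.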

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 67, The helper
is_sm100_family() currently returns True for SM120 devices, so split it into a
properly named broad helper (e.g., is_sm100_or_sm120_family or
is_supported_sm_family) and a true SM100-only helper (e.g., is_sm100_only) while
keeping is_sm120_family as-is; then update the skip markers sm100_required to
use the broad helper and sm100_only to use the new is_sm100_only helper (replace
references to is_sm100_family in those decorators) so names accurately reflect
predicates and avoid misuse.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1360-1369: The benchmark currently sets use_functional =
getattr(args, "use_functional_api", False) and imports cute_dsl_fused_moe_nvfp4
when true, but does not record this choice in the printed/saved metadata; update
the code that builds and prints/saves benchmark metadata (the same place that
emits existing run config and verbose info) to include a clear field like
"runner_mode" or "use_functional_api" set from use_functional, ensure the
console message and any persisted result/metadata blob include this flag/value
so runs using cute_dsl_fused_moe_nvfp4 are distinguishable, and apply the same
change to the other block range (lines around 1461-1535) where the alternative
runner is selected.
- Around line 1360-1369: The functional API path guarded by use_functional that
imports cute_dsl_fused_moe_nvfp4 should not be captured in a CUDA graph; modify
the control flow where CUDA-graph capture is applied so that when use_functional
is True (i.e., running cute_dsl_fused_moe_nvfp4) you run the benchmark loop
normally without graph capture, preserving per-call workspace allocations and
measurements instead of wrapping it in the graph capture logic used for the
non-functional path; update both occurrences around use_functional (lines
referencing cute_dsl_fused_moe_nvfp4 and the graph-capture block) to skip graph
creation/launch and run the functional call directly.

In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py`:
- Around line 199-254: The cache key for _WEIGHT_CACHE currently only includes
FP4 and blockscale pointers causing stale w1_alphas/w2_alphas to be reused;
update the cache key construction (the key tuple created where
w1_fp4.data_ptr()... is used) to also incorporate the alpha buffers (e.g.,
w1_alphas.contiguous().data_ptr() and w2_alphas.contiguous().data_ptr() or a
stable identifier derived from their storage/shape/dtype) so that _WeightViews
created in this block (views._w13_sf_storage, views._down_sf_storage, and the
w1_alpha/w2_alpha copies passed into _WeightViews) are unique per alpha content
and not reused when alphas change. Ensure any contiguity conversion used for the
key matches the contiguity used when storing the copied alphas (references:
_WEIGHT_CACHE, w1_alphas, w2_alphas, and the _WeightViews construction).

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 384-385: Public SM12x entrypoints decorated with
`@supported_compute_capability` and `@flashinfer_api` (the functions currently
annotated at the top with those decorators) must also include the repo-standard
`@backend_requirement` metadata; update each affected API (the decorators shown
and the other similar decorator sites for the SM12x-only bf16 entrypoints) to
add a `@backend_requirement`(...) decorator reflecting the same SM12x
capabilities/backends (e.g., include the SM12x backend and matching compute
capability set) placed above or alongside `@supported_compute_capability` so the
public API functions carry `@backend_requirement` in addition to `@flashinfer_api`
and `@supported_compute_capability`.
- Around line 690-711: The cache key for _sm120_weight_views misses the
per-expert alpha tensors; update the weight_key tuple used to detect
invalidation to include w1_alpha.data_ptr() and w2_alpha.data_ptr() so changes
to w1_alpha or w2_alpha will invalidate and rebuild via _get_sm120_weight_views
(referenced symbols: _sm120_weight_views, _sm120_weight_key, weight_key,
w1_alpha, w2_alpha, _get_sm120_weight_views).
- Around line 474-505: The pre-allocation currently creates only one SM120
workspace in _allocate_buffers() (using allocate_sm120_dynamic_workspace or
allocate_sm120_static_workspace) and stores it in self._sm120_workspace, which
can mismatch when launch_sm120_moe() reselects backend based on current routed
rows; fix by pre-allocating and storing both workspace variants (or storing
separate attributes, e.g., self._sm120_dynamic_workspace and
self._sm120_static_workspace) and update the codepath that calls
launch_sm120_moe() (and any wrapper creation) to select and pass the matching
workspace based on select_sm120_moe_backend(num_tokens/current_routed_rows) or
the actual routed_rows value, ensuring Sm120DynamicMoEWorkspace objects are only
given to the dynamic launcher and Sm120StaticMoEWorkspace objects only to the
static launcher (or lazily allocate the correct type on demand if absent).

---

Nitpick comments:
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 38-67: The helper is_sm100_family() currently returns True for
SM120 devices, so split it into a properly named broad helper (e.g.,
is_sm100_or_sm120_family or is_supported_sm_family) and a true SM100-only helper
(e.g., is_sm100_only) while keeping is_sm120_family as-is; then update the skip
markers sm100_required to use the broad helper and sm100_only to use the new
is_sm100_only helper (replace references to is_sm100_family in those decorators)
so names accurately reflect predicates and avoid misuse.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c3aacbb1-14cd-4dec-9d9b-6c026dabb8c4

📥 Commits

Reviewing files that changed from the base of the PR and between 69ff093 and aba8593.

📒 Files selected for processing (6)
  • benchmarks/routines/moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_static_kernel.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment on lines +384 to 385
@supported_compute_capability([100, 103, 120, 121])
@flashinfer_api
@coderabbitai coderabbitai bot Apr 14, 2026


🛠️ Refactor suggestion | 🟠 Major

Use @backend_requirement on the public SM12x APIs.

These entrypoints now expose architecture-specific behavior/support (x_bf16 is SM12x-only, and the supported capability set was extended to 12.x), but they still rely only on @flashinfer_api / @supported_compute_capability. Please attach the repo-standard @backend_requirement metadata here as well, per the coding guideline: use the @backend_requirement decorator on APIs with architecture-specific requirements to track supported compute capabilities.

Also applies to: 628-629, 829-830

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 384 - 385, Public
SM12x entrypoints decorated with `@supported_compute_capability` and
`@flashinfer_api` (the functions currently annotated at the top with those
decorators) must also include the repo-standard `@backend_requirement` metadata;
update each affected API (the decorators shown and the other similar decorator
sites for the SM12x-only bf16 entrypoints) to add a `@backend_requirement`(...)
decorator reflecting the same SM12x capabilities/backends (e.g., include the
SM12x backend and matching compute capability set) placed above or alongside
`@supported_compute_capability` so the public API functions carry
`@backend_requirement` in addition to `@flashinfer_api` and
`@supported_compute_capability`.

Collaborator Author

I'd push back on this one. @backend_requirement is for APIs that dispatch across multiple backends via a backend= parameter (e.g., mm_fp4 choosing between cudnn/cutlass/cute-dsl/b12x). cute_dsl_fused_moe_nvfp4 is a specific backend — it has no backend=kwarg and no alternative implementations to select between.

This is consistent with every other single-backend MoE function in core.py: cutlass_fused_moe, trtllm_bf16_moe, trtllm_fp4_block_scale_moe, etc. — none use @backend_requirement.

@supported_compute_capability([100, 103, 120, 121]) already provides the is_compute_capability_supported() introspection method and CC gating. Adding @backend_requirement({}) on top would introduce the backend-dispatch machinery (skip_check, auto-selection, BackendSupportedError) with nothing to dispatch to.

Contributor

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
flashinfer/fused_moe/cute_dsl/fused_moe.py (2)

434-434: Minor: Same unused minor variable nit.

-        major, minor = torch.cuda.get_device_capability(device)
+        major, _ = torch.cuda.get_device_capability(device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` at line 434, The call to
torch.cuda.get_device_capability(device) assigns both major and minor but minor
is unused; update the assignment in fused_moe.py (the line using
torch.cuda.get_device_capability) to either unpack only the needed value (e.g.,
assign to major and ignore the second value) or replace the unused variable with
an underscore (e.g., major, _ = torch.cuda.get_device_capability(device)) to
avoid the unused-variable nit.

189-189: Minor: Prefix unused minor variable with underscore.

Static analysis correctly identifies that minor is unpacked but never used. Consider using _ prefix to indicate intentionally unused.

-        major, minor = torch.cuda.get_device_capability(x.device)
+        major, _ = torch.cuda.get_device_capability(x.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` at line 189, The unpacking from
torch.cuda.get_device_capability(x.device) assigns an unused minor value; update
the unpack to mark it unused (e.g., change the tuple target from "major, minor"
to "major, _" or "major, _minor") in fused_moe.py at the call site where
x.device is used so static analysis no longer flags the unused variable; ensure
there are no other references to "minor" in the surrounding code before making
the change.
flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py (1)

54-65: Unused helper function _as_grouped_scale_view.

This function is defined but not called within this file. If it's intended for external use or future work, consider adding a brief docstring note or moving it to a utilities module. Otherwise, it could be removed to reduce dead code.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py` around lines
54 - 65, The helper function _as_grouped_scale_view is defined but never used;
either remove it to eliminate dead code, or make its intent explicit by moving
it to a utilities module and/or documenting it: if it's meant for external
reuse, relocate the function (e.g., to a utils module), change its name to
remove the leading underscore if it should be public, add a brief docstring
stating its purpose and expected input shapes, and update any exports/imports
accordingly; otherwise delete the _as_grouped_scale_view definition from this
file.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py`:
- Around line 250-251: The cached alpha tensors stored in _get_weight_views
(w1_alphas, w2_alphas -> weights.w1_alpha, weights.w2_alpha) can remain scalar
[1] and therefore mismatch the kernel's expected (weight_E,) shape; fix by
expanding scalar alpha tensors to [num_experts] before caching (apply the same
logic as _expand_to_experts or call _expand_to_experts on w1_alphas/w2_alphas in
_get_weight_views), then make them contiguous and .to(torch.float32) so
launch_sm120_static_moe receives correctly-shaped alpha tensors when invoking
the kernel.
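The suggested expansion can be sketched in pure Python (a simplified stand-in for `_expand_to_experts`; the real code would operate on `torch.Tensor` and apply `.contiguous().to(torch.float32)` afterwards):

```python
def expand_to_experts(alpha, num_experts: int):
    """Broadcast a scalar [1] alpha to per-expert shape (num_experts,).

    Already-per-expert alphas pass through unchanged; anything else
    is a shape mismatch the caller should hear about immediately.
    """
    if len(alpha) == num_experts:
        return list(alpha)
    if len(alpha) == 1:
        return [alpha[0]] * num_experts
    raise ValueError(
        f"alpha has {len(alpha)} entries, expected 1 or {num_experts}"
    )
```

Applying this before caching guarantees `launch_sm120_static_moe` always sees `(weight_E,)`-shaped alpha tensors, regardless of how the caller supplied them.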

---

Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py`:
- Around line 54-65: The helper function _as_grouped_scale_view is defined but
never used; either remove it to eliminate dead code, or make its intent explicit
by moving it to a utilities module and/or documenting it: if it's meant for
external reuse, relocate the function (e.g., to a utils module), change its name
to remove the leading underscore if it should be public, add a brief docstring
stating its purpose and expected input shapes, and update any exports/imports
accordingly; otherwise delete the _as_grouped_scale_view definition from this
file.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Line 434: The call to torch.cuda.get_device_capability(device) assigns both
major and minor but minor is unused; update the assignment in fused_moe.py (the
line using torch.cuda.get_device_capability) to either unpack only the needed
value (e.g., assign to major and ignore the second value) or replace the unused
variable with an underscore (e.g., major, _ =
torch.cuda.get_device_capability(device)) to avoid the unused-variable nit.
- Line 189: The unpacking from torch.cuda.get_device_capability(x.device)
assigns an unused minor value; update the unpack to mark it unused (e.g., change
the tuple target from "major, minor" to "major, _" or "major, _minor") in
fused_moe.py at the call site where x.device is used so static analysis no
longer flags the unused variable; ensure there are no other references to
"minor" in the surrounding code before making the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 18a08836-362b-4cb7-af82-b31f31595299

📥 Commits

Reviewing files that changed from the base of the PR and between aba8593 and 6d738bf.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py

Comment on lines +250 to +251
w1_alpha=w1_alphas.contiguous().to(torch.float32),
w2_alpha=w2_alphas.contiguous().to(torch.float32),
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if _expand_to_experts is used for other alpha-related parameters
# and verify the pattern is consistently applied
rg -n "_expand_to_experts" flashinfer/fused_moe/cute_dsl/blackwell_geforce/

# Check kernel fake tensor shapes for alpha parameters
rg -n "alpha_fake.*make_fake" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 996


🏁 Script executed:

# Get the actual shape of alpha_fake and down_alpha_fake at the key lines
sed -n '410,420p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py
sed -n '1022,1032p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 738


🏁 Script executed:

# Check how alphas are used in the launch functions around lines 515-520 and 1117-1120
sed -n '510,530p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py
sed -n '1112,1135p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1499


🏁 Script executed:

# Check the function signature and how weights are passed to kernel
sed -n '500,540p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py
sed -n '1100,1150p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 2934


🏁 Script executed:

# Get the complete compiled() kernel call for static kernel
sed -n '525,570p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1497


🏁 Script executed:

# Get the complete compiled() kernel call for dynamic kernel
sed -n '1130,1180p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1749


🏁 Script executed:

# Check the function signatures to see where alphas come from as parameters
sed -n '495,510p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py
sed -n '1100,1112p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 853


🏁 Script executed:

# Find all calls to _get_weight_views to see what alphas are passed
rg -n "_get_weight_views" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 498


🏁 Script executed:

# Get the _get_weight_views function signature and full implementation
sed -n '240,280p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1272


🏁 Script executed:

# Check where alphas come from in the cutlass wrapper/MoEWrapper
rg -n "w1_alpha\|w2_alpha" flashinfer/fused_moe/cute_dsl/blackwell_geforce/ -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Find the CuteDslMoEWrapper class and the run/call that passes w1_alpha to _get_weight_views
sed -n '1280,1310p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1125


🏁 Script executed:

# Search more broadly for the wrapper class and run method
rg -n "class CuteDslMoEWrapper\|def run" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py -A 3

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Find where w1_alpha parameter comes from in the outer wrapper
rg -n "w1_alpha" flashinfer/fused_moe/cute_dsl/fused_moe.py -B 3 -A 3 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1996


🏁 Script executed:

# Find the signature and docstring of the wrapper to understand alpha shape
sed -n '100,200p' flashinfer/fused_moe/cute_dsl/fused_moe.py

Repository: flashinfer-ai/flashinfer

Length of output: 4197


🏁 Script executed:

# Find where w1_alpha is created/prepared before being passed to the wrapper
rg -n "w1_alpha\s*=" flashinfer/fused_moe/cute_dsl/fused_moe.py -B 2 -A 2 | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 1084


🏁 Script executed:

# Check if there are tests that show how w1_alpha is created
fd -t f "test.*moe" flashinfer/fused_moe/cute_dsl/ | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for test files that use the MoE functions
find flashinfer -name "*test*moe*" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there are examples showing how w1_alpha is structured
rg -n "w1_alpha\s*=.*torch\|w1_alpha\s*=.*ones\|w1_alpha\s*=.*zeros" flashinfer/ --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the docstring more carefully to see if it specifies shape
sed -n '150,165p' flashinfer/fused_moe/cute_dsl/fused_moe.py

Repository: flashinfer-ai/flashinfer

Length of output: 873


🏁 Script executed:

# Search where w1_alpha is computed/prepared in upper-level code
rg -n "w1_alpha" flashinfer/ --type py -B 3 -A 3 | grep -E "(expand|shape|numel|torch\.)" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 2960


🏁 Script executed:

# Check the CUTLASS backend or earlier code to see what shape alpha has
rg -n "def.*moe.*alpha\|w1_alpha.*shape\|w1_alpha.*numel" flashinfer/ --type py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there's a pattern in how input_gs is created vs w1_alpha
rg -n "input_gs\s*=" flashinfer/fused_moe/cute_dsl/fused_moe.py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if w1_alpha is scalar or per-expert at any call site
rg -n "w1_alpha\s*=\s*\(" flashinfer/ --type py -A 2 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look at the CuteDslMoEWrapper.run signature more carefully to understand alpha shape
sed -n '1250,1280p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1104


🏁 Script executed:

# Check if there's any shape assertion or validation for alphas
rg -n "assert.*alpha.*shape\|alpha.*shape.*assert\|\.shape\[0\].*expert" flashinfer/fused_moe/cute_dsl/blackwell_geforce/ --type py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for where w1_alpha originates - check if it's prepared/quantized upstream
rg -n "\.w1_alpha\|\"w1_alpha\"\|'w1_alpha'" flashinfer/ --type py | grep -v "torch\|test" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the CuteDslMoEWrapper init to see how w1_alpha is stored
sed -n '1220,1260p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1205


🏁 Script executed:

# Look at the full launch_sm120_moe call to understand alpha shapes before it's called
sed -n '1280,1350p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 2549


🏁 Script executed:

# Check the static backend launch call
sed -n '1350,1400p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 639


🏁 Script executed:

# Check the full context of how input_gs parameter is used in both static and dynamic
rg -n "def launch_sm120_static_moe\|def launch_sm120_dynamic_moe" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py -A 15

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Verify that input_gs in launch_sm120_moe is indeed w1_alpha
sed -n '1326,1375p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1350


🏁 Script executed:

# Check if input_gs and weights.w1_alpha are the same or different in kernel call
# Look at lines 555-565 to see full kernel call with all parameters
sed -n '540,580p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 1312


🏁 Script executed:

# Trace the data flow: w1_alpha -> input_gs parameter -> expanded in launch functions -> passed to kernel
# Also check if weights.w1_alpha is actually used in the kernel or if it's duplicative
rg -n "weights\.w1_alpha\|weights\.w2_alpha" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there's kernel code or documentation that shows which parameters are actually used
# Look for the kernel implementation or fake tensor binding to understand parameter ordering
rg -n "alpha.*input\|alpha.*scale" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py -A 2 -B 2 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the static kernel compilation to see parameter binding
sed -n '380,450p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 2054


🏁 Script executed:

# Check if w1_alpha from weights is redundant or actually used differently
# Search for any branching or conditional use of weights.w1_alpha
rg -n "weights\.w1_alpha\|weights\.w2_alpha" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Understand the kernel parameter order - get the compile() call to see parameter binding
sed -n '470,550p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py | grep -A 60 "compiled = cute.compile"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there's a distinction between input_gs and weights.w1_alpha in the kernel
# They might be different parameters (input_gs for GEMM1 input scale, weights.w1_alpha for weight scale)
rg -n "input_gs\|w1_alpha" flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py -B 1 -A 1 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the upstream caller to see what shape w1_alpha has when it enters launch_sm120_moe
rg -n "launch_sm120_moe\(" flashinfer/fused_moe/cute_dsl/fused_moe.py -B 5 -A 15

Repository: flashinfer-ai/flashinfer

Length of output: 1855


🏁 Script executed:

# Check the learnings again - are there any mentions of alpha shape handling in SM120?
# Also check if there's documentation about whether alphas should be scalar or per-expert
rg -n "scalar\|per-expert\|alpha.*shape" flashinfer/fused_moe/cute_dsl/fused_moe.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there's any validation or shape checking for w1_alpha/w2_alpha before they reach launch_sm120_moe
rg -n "assert.*alpha\|alpha.*size\|alpha.*shape" flashinfer/fused_moe/cute_dsl/ --type py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Key question: is input_gs being used for a DIFFERENT purpose than weights.w1_alpha?
# Look at the kernel code to understand what each parameter does
sed -n '1,50p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dynamic_kernel.py

Repository: flashinfer-ai/flashinfer

Length of output: 2105


🏁 Script executed:

# Critical insight: check what input_gs actually is and if weights.w1_alpha is redundant
# Look at the kernel parameter binding in the compile call
sed -n '460,530p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py | tail -80

Repository: flashinfer-ai/flashinfer

Length of output: 2157


🏁 Script executed:

# Get the complete compile() call to see parameter order and understand what each fake tensor maps to
sed -n '461,530p' flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py

Repository: flashinfer-ai/flashinfer

Length of output: 2132


🏁 Script executed:

# Final verification: confirm that w1_alpha CAN be scalar [1] by checking if there's any branching
# or if alphas always come in [E] shape from upstream
rg -n "w1_alpha.shape\|w1_alpha.numel\|w1_alpha.*reshape\|w1_alpha.*expand" flashinfer/fused_moe/ --type py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there's any documentation or comments about alpha shapes
rg -n "alpha.*scalar\|alpha.*per-expert\|alpha.*shape\|\[1\].*alpha" flashinfer/fused_moe/cute_dsl/ --type py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


Scalar alphas are not expanded when cached, causing a shape mismatch with the kernel.

_get_weight_views stores alphas via .contiguous().to(torch.float32) without expanding scalar [1] tensors to [E] shape. The kernel is compiled with fake tensors expecting (weight_E,) shape for alpha_fake and down_alpha_fake, but weights.w1_alpha and weights.w2_alpha can remain [1].

Meanwhile, launch_sm120_static_moe expands the separate input_gs parameter via _expand_to_experts (lines 515-516), yet passes the unexpanded cached weights.w1_alpha and weights.w2_alpha directly to the kernel at lines 557-558.

The _expand_to_experts function (line 484) explicitly handles scalar [1] tensors by expanding them to [num_experts]. The same logic should be applied to the cached alphas before they reach the kernel.
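The expansion behavior described here can be sketched as a small standalone helper. This is an illustrative mock (the name expand_to_experts and its exact signature are assumptions, mirroring the _expand_to_experts logic cited above), not the actual FlashInfer implementation:

```python
import torch

def expand_to_experts(alpha: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Hypothetical sketch of the _expand_to_experts pattern: a scalar [1]
    # alpha is broadcast to [num_experts]; a per-expert alpha passes through.
    # Both paths end contiguous and in float32, matching the kernel's
    # expected (weight_E,) fake-tensor shape.
    if alpha.numel() == 1:
        return alpha.expand(num_experts).contiguous().to(torch.float32)
    return alpha.contiguous().to(torch.float32)

scalar_alpha = torch.tensor([0.5])            # shape [1], as cached today
per_expert_alpha = torch.rand(8)              # shape [E], already expanded
print(expand_to_experts(scalar_alpha, 8).shape)      # torch.Size([8])
print(expand_to_experts(per_expert_alpha, 8).shape)  # torch.Size([8])
```

Applying this at caching time would make both alpha shapes uniform before launch_sm120_static_moe hands them to the compiled kernel.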

Suggested fix: expand alphas before caching
+    num_experts = w1_fp4.shape[0]
+
     views = _WeightViews(
         w13_fp4=w13.view(torch.float4_e2m1fn_x2),
         down_fp4=down.view(torch.float4_e2m1fn_x2),
         sfb_w13_ptr=make_ptr(
             sf_dtype,
             w13_sf_contiguous.data_ptr(),
             cute.AddressSpace.gmem,
             assumed_align=16,
         ),
         sfb_down_ptr=make_ptr(
             sf_dtype,
             down_sf_contiguous.data_ptr(),
             cute.AddressSpace.gmem,
             assumed_align=16,
         ),
-        w1_alpha=w1_alphas.contiguous().to(torch.float32),
-        w2_alpha=w2_alphas.contiguous().to(torch.float32),
+        w1_alpha=(
+            w1_alphas.expand(num_experts).contiguous().to(torch.float32)
+            if w1_alphas.numel() == 1
+            else w1_alphas.contiguous().to(torch.float32)
+        ),
+        w2_alpha=(
+            w2_alphas.expand(num_experts).contiguous().to(torch.float32)
+            if w2_alphas.numel() == 1
+            else w2_alphas.contiguous().to(torch.float32)
+        ),
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_geforce/moe_dispatch.py` around lines
250 - 251, The cached alpha tensors stored in _get_weight_views (w1_alphas,
w2_alphas -> weights.w1_alpha, weights.w2_alpha) can remain scalar [1] and
therefore mismatch the kernel's expected (weight_E,) shape; fix by expanding
scalar alpha tensors to [num_experts] before caching (apply the same logic as
_expand_to_experts or call _expand_to_experts on w1_alphas/w2_alphas in
_get_weight_views), then make them contiguous and .to(torch.float32) so
launch_sm120_static_moe receives correctly-shaped alpha tensors when invoking
the kernel.

@bkryu
Collaborator Author

bkryu commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !547 has been created, and the CI pipeline #48536124 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Apr 14, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48536124 has been cancelled.

Collaborator

@aleozlx aleozlx left a comment


left minor naming suggestion, lgtm

@bkryu
Collaborator Author

bkryu commented Apr 14, 2026

left minor naming suggestion, lgtm

Thanks. Renamed blackwell_geforce to blackwell_sm12x as per your suggestion.

@bkryu
Collaborator Author

bkryu commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !547 has been updated with latest changes, and the CI pipeline #48537831 is currently running. I'll report back once the pipeline job completes.

Collaborator

@nv-yunzheq nv-yunzheq left a comment


Didn't inspect kernel code. But the rest of integration looks good to me

@bkryu
Collaborator Author

bkryu commented Apr 14, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48537831 has been cancelled.

@bkryu
Collaborator Author

bkryu commented Apr 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !547 has been updated with latest changes, and the CI pipeline #48545646 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Apr 15, 2026

/bot stop

@bkryu
Collaborator Author

bkryu commented Apr 15, 2026

/bot run

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48545646 has been cancelled.

@flashinfer-bot
Collaborator

GitLab MR !547 has been updated with latest changes, and the CI pipeline #48549567 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Apr 15, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #48549567 has been cancelled.

@bkryu bkryu mentioned this pull request Apr 15, 2026
@bkryu
Collaborator Author

bkryu commented Apr 15, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !547 has been updated with latest changes, and the CI pipeline #48566011 is currently running. I'll report back once the pipeline job completes.
