
WIP: B12x micro kernel merged #3098

Open
askliar wants to merge 21 commits into flashinfer-ai:main from askliar:b12x_micro_kernel_merged

Conversation

@askliar
Contributor

@askliar askliar commented Apr 17, 2026

Summary by CodeRabbit

  • New Features

    • Added SM120/SM121 b12x fused MoE functional API and CUDA-graph-friendly wrapper, plus a micro-kernel path for small-token workloads.
    • Added ReLU² ("relu2") activation support for FP4 MoE.
  • Benchmarks

    • FLOPs and bandwidth calculations now distinguish gated vs non-gated activations.
  • Tests

    • Added comprehensive b12x fused MoE tests covering functional, wrapper, micro-kernel, CUDA-graph, and ReLU² paths.
  • Chores

    • Exposed new fused-MoE symbols for CuteDSL-enabled installs; improved a hardware-probe fallback.
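
The ReLU² activation listed above is just ReLU followed by squaring. A minimal reference sketch (the function name is illustrative, not the actual helper in fp4_common.py):

```python
import torch

def relu2_reference(x: torch.Tensor) -> torch.Tensor:
    # ReLU2: clamp negatives to zero, then square. Unlike gated SiLU,
    # this acts element-wise on a single FC1 output (no gate half),
    # which is why the non-gated FC1 weight needs I rows, not 2*I.
    r = torch.clamp(x, min=0.0)
    return r * r
```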

@coderabbitai
Contributor

coderabbitai bot commented Apr 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds SM120/SM121 CuTe-DSL b12x fused‑MoE implementation (functional API + wrapper), threads activation mode (SiLU vs ReLU²) through dispatch/kernels/benchmarks, adds Triton compaction and FP4 ReLU² quant helpers, updates CUTLASS/NVFP4 paths and benchmarks, and adds comprehensive SM12x tests including CUDA-graph and micro-kernel coverage.

Changes

Cohort / File(s) Summary
Benchmarks
benchmarks/routines/moe.py, benchmarks/routines/moe_utils.py
Threaded activation/is_gated into benchmark routines and FLOPs/bandwidth calculators; test-data generation and weight sizing now handle gated (2*I) vs non-gated (I) FC1 shapes.
Top-level exports
flashinfer/__init__.py, flashinfer/fused_moe/__init__.py, flashinfer/fused_moe/cute_dsl/__init__.py
Conditionally re-export new SM12x CuTe-DSL symbols b12x_fused_moe and B12xMoEWrapper.
FP4 helpers
flashinfer/cute_dsl/fp4_common.py
Added ReLU² FP4 helper functions (relu2_16, relu2_quantize_block_fp4).
SM12x API (new)
flashinfer/fused_moe/cute_dsl/b12x_moe.py
Added b12x_fused_moe() functional API and B12xMoEWrapper (CUDA ≥13 / bf16 output checks, workspace/weight‑view management, activation parameter).
Dispatch & compaction
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py, .../triton_compact.py, .../blackwell_sm12x/__init__.py
Added micro backend and selection logic, threaded activation through launch/compile APIs, weight‑view shapes conditional on gating, Triton compact_topk_ids utility, and exported MoEMicroKernel / sm120_moe_dispatch_context.
Static & Dynamic kernels
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py, .../moe_dynamic_kernel.py
Kernels accept activation ("silu" or "relu2"), compute is_gated, and adapt FC1 shapes, storage layouts, pipelines, tiling, and activation/quantization math accordingly.
CuteDSL NVFP4 fused MoE
flashinfer/fused_moe/cute_dsl/fused_moe.py (CuteDslMoEWrapper)
Removed SM120/SM121 direct-dispatch from NVFP4 path; NVFP4 wrapper/support reduced to SM100/SM103 only.
Utils
flashinfer/cute_dsl/utils.py
get_max_active_clusters wrapped to catch probe errors, warn, and fall back to device SM count.
GEMM support
flashinfer/gemm/gemm_base.py
Expanded b12x FP4 GEMM support to include SM121 and treat SM120.0/120.1 uniformly.
Tests
tests/moe/test_b12x_fused_moe.py, tests/moe/test_cute_dsl_fused_moe.py
Added comprehensive SM120/121 b12x tests (functional, wrapper, micro/static/dynamic, gated/non‑gated, CUDA-graph). Updated NVFP4 tests to be SM100-only and removed SM12x branches.
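
The gated-vs-non-gated FC1 sizing threaded through the benchmarks and kernels above reduces to a simple rule; a sketch with illustrative helper names (not from the PR):

```python
def fc1_weight_rows(intermediate_size: int, is_gated: bool) -> int:
    # Gated activations (SwiGLU/SiLU) fuse a gate and an up projection
    # into FC1, so the weight carries 2*I rows; non-gated activations
    # (ReLU2) need only I rows.
    return 2 * intermediate_size if is_gated else intermediate_size

def fc1_gflops(num_tokens: int, hidden: int, intermediate: int, is_gated: bool) -> float:
    # FLOPs for the FC1 GEMM: 2 * M * N * K, where N depends on gating.
    n = fc1_weight_rows(intermediate, is_gated)
    return 2.0 * num_tokens * n * hidden / 1e9
```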

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant FuncAPI as b12x_fused_moe()
    participant Wrapper as B12xMoEWrapper
    participant Dispatch as launch_sm120_moe()
    participant Selector as BackendSelector
    participant Kernel as Static/Dynamic/MicroKernel

    User->>FuncAPI: call(x, quantized weights, routing, activation)
    FuncAPI->>FuncAPI: validate CUDA>=13 and bf16 output, derive dims
    FuncAPI->>Dispatch: invoke with activation and weight views
    Dispatch->>Dispatch: compute w1_rows from activation (is_gated?)
    Dispatch->>Selector: choose backend by routed rows / tokens / top_k
    alt micro path
        Selector->>Kernel: compact_topk_ids (if needed) then launch micro kernel
    else static path
        Selector->>Kernel: launch static kernel
    else dynamic path
        Selector->>Kernel: launch dynamic kernel
    end
    Kernel->>Kernel: apply activation (SiLU gating or ReLU²), quantize, GEMM2 finalize
    Kernel-->>Dispatch: return output
    Dispatch-->>FuncAPI: return result
    FuncAPI-->>User: emit output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

cute-dsl

Suggested reviewers

  • bkryu
  • yzh119
  • sricketts
  • cyx-6
  • samuellees
  • aleozlx
  • kahyunnam
  • jimmyzho

Poem

🐇 I hopped through kernels, nudged the gate ajar,

SiLU or ReLU², each path now gleams like a star.
Micro, static, dynamic — dispatch finds its tune,
Buffers prepped, tests run, outputs land by noon.
A rabbit cheers: fused MoE, perfectly in tune!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Description check (⚠️ Warning): No pull request description was provided. The required template specifies sections for Description, Related Issues, and checklist items, all of which are missing. Resolution: add a comprehensive description covering what changed, why it is needed, and any related issues, and confirm pre-commit checks and tests pass.
  • Title check (❓ Inconclusive): The title 'WIP: B12x micro kernel merged' is vague. 'WIP' indicates work-in-progress status, and 'merged' is ambiguous: does the PR merge code, or describe a merged state? Resolution: replace it with a descriptive title that summarizes the primary change, e.g., 'Add SM120/SM121 B12x fused MoE support with micro kernel path'.

✅ Passed checks (1 passed)

  • Docstring Coverage (✅ Passed): Docstring coverage is 80.90%, which meets the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for Blackwell (SM120/SM121) fused MoE kernels, providing both functional and wrapper APIs. Key additions include support for ReLU2 activation, a new 'micro' kernel path optimized for small decode batches with specific MAC tuning ladders, and a Triton-based expert ID compaction pre-pass. Feedback identifies a potential bug in the single-token optimization path where workspace metadata is not updated, which could lead to incorrect results. It also suggests using defined constants instead of hardcoded values for tile sizes to improve maintainability.

Comment on lines +821 to +840
# For a single token, topk_ids already contains global expert IDs
# densely (one expert per top-k slot), so the Triton compaction
# pre-pass is pure overhead and can be skipped.
launch_ids = flat_ids
if num_tokens != 1:
    from .triton_compact import compact_topk_ids as _triton_compact_topk_ids

    # Run Triton pre-pass to compact global expert IDs to dense local indices
    assert flat_ids.numel() <= workspace.compact_topk_ids.numel(), (
        f"compact_topk_ids buffer too small: "
        f"{workspace.compact_topk_ids.numel()} < {flat_ids.numel()}"
    )
    compact_ids = workspace.compact_topk_ids[: flat_ids.numel()]
    _triton_compact_topk_ids(
        flat_ids,
        compact_ids,
        workspace.weight_expert_ids,
        workspace.active_expert_count,
    )
    launch_ids = compact_ids
Contributor

high

The optimization for num_tokens == 1 in the micro kernel path skips the Triton compaction pre-pass but fails to update workspace.active_expert_count and workspace.weight_expert_ids. If the micro kernel relies on these values (e.g., to bound expert loops or map local indices to global IDs), it will produce incorrect results or access memory out of bounds, especially since active_expert_count is initialized to 0 and weight_expert_ids is an identity map that might not match the global IDs in flat_ids for EP configurations. Furthermore, active_expert_count will persist values from previous calls if not explicitly reset or updated.
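
One way to address this, sketched under the assumption that the workspace layout matches the snippet above (the helper name is hypothetical), is to host-fill on the skip path the metadata the compaction pass would have produced:

```python
import torch

def fill_single_token_metadata(flat_ids, weight_expert_ids, active_expert_count):
    # When the Triton compaction pre-pass is skipped (num_tokens == 1),
    # populate the metadata it would have produced: local slot i maps
    # back to global expert id flat_ids[i], and the active-expert count
    # is the number of routed slots. Without this, the kernel would read
    # zero-initialized defaults or stale values from a previous call.
    k = flat_ids.numel()
    assert k <= weight_expert_ids.numel(), "weight_expert_ids buffer too small"
    weight_expert_ids[:k].copy_(flat_ids.to(torch.int32))
    active_expert_count.fill_(k)
```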

# Select micro MAC: min of tuned ladder, work tiles, and hardware limit.
# The hardware cap (base_mac) prevents deadlocks on GPUs with fewer SMs
# than the profiled tuning target.
micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
Contributor

medium

The tile size 128 is hardcoded here. It should use the constant _LEVEL_TILE_N defined at line 32 to maintain consistency and allow for easier tuning of the tile dimensions.

Suggested change
- micro_work_tiles = max(1, routed_rows * max(1, (n + 128 - 1) // 128))
+ micro_work_tiles = max(1, routed_rows * max(1, (n + _LEVEL_TILE_N - 1) // _LEVEL_TILE_N))

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
tests/moe/test_cute_dsl_fused_moe.py (1)

57-75: Dead helper after skip-condition rewrite.

_has_cuda_13() is no longer referenced now that the CUDA-13 check was removed from sm100_required. Consider deleting it (and the now-unused import of it elsewhere, if any) to avoid stale code.

Proposed cleanup
-def _has_cuda_13():
-    """Check if CUDA runtime version is 13+."""
-    return (
-        torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 13
-    )
-
-
 # Skip decorators
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 57 - 75, Remove the
now-unused helper function _has_cuda_13 and any imports or references to it
elsewhere in the codebase; locate the definition of _has_cuda_13 in the test
module (and any other modules that import it) and delete the function and its
import lines, run the test suite to ensure nothing else depends on it, and
commit the cleanup.
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py (1)

67-86: Minor: buffer-size check for weight_expert_ids is absent.

The docstring states weight_expert_ids must have >= num_unique_experts entries, but the wrapper validates only compact_topk_ids and active_expert_count. If a caller passes an undersized buffer the tl.store(weight_expert_ids_ptr + compact_id, ...) will silently OOB (masked only by valid & first_flags, not by buffer length). Since num_unique_experts isn't known until runtime, a conservative weight_expert_ids.numel() >= min(total_pairs, state_E)-style assert would catch misuse in debug builds.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines
67 - 86, Add a guard before launching _compact_topk_ids_kernel to validate the
weight_expert_ids buffer is large enough: compute total_pairs = topk_ids.numel()
(already present) and compare weight_expert_ids.numel() against a conservative
bound such as min(total_pairs, state_E) or the expected number of unique experts
from state_E, and raise ValueError if it is smaller; update the validation block
that currently checks compact_topk_ids and active_expert_count so it also checks
weight_expert_ids and includes the same early-return/exception behavior before
calling _compact_topk_ids_kernel.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1367-1376: The code allows non-gated ActivationType (Relu2) to
fall through into the legacy CuteDSL path which only supports gated activations;
update the guard so that before taking the non-SM120/legacy branch you validate
that activation_type is gated (i.e., is_gated True) and raise a ValueError
otherwise. Specifically, check activation_type/is_gated prior to calling
CuteDslMoEWrapper or cute_dsl_fused_moe_nvfp4 (and before using
_create_cute_dsl_moe_test_data which shrinks w1), and if not gated, raise an
informative error (e.g., "CuTe DSL MoE only supports gated activations
(SwiGLU)") so Relu2 cannot be passed into the SwiGLU-only kernels.

In `@flashinfer/fused_moe/cute_dsl/b12x_moe.py`:
- Around line 215-238: The allocation path in _allocate_buffers uses
select_sm120_moe_backend and can choose a dynamic workspace via
allocate_sm120_dynamic_workspace even when num_local_experts != num_experts,
which later causes launch_sm120_moe to force static and results in OOB writes;
update _allocate_buffers to mirror launch_sm120_moe's EP fallback: if
use_cuda_graph is true or if num_local_experts != num_experts, force choosing
the static backend (call allocate_sm120_static_workspace) instead of dynamic
(allocate_sm120_dynamic_workspace) so preallocated workspace type matches the
launch-time routing logic (affecting B12xMoEWrapper, _workspace,
allocate_sm120_dynamic_workspace, allocate_sm120_static_workspace,
select_sm120_moe_backend, and launch_sm120_moe).

In `@tests/moe/test_b12x_fused_moe.py`:
- Line 359: The test uses torch.zeros(..., device="cuda") which forces the
default CUDA device and can cause cross-device copies or failures; replace
device="cuda" (and any .cuda() calls) with the existing tensor's device (e.g.,
use hidden_states.device or the corresponding tensor.device) when creating
reference tensors like output so they stay on the same device as hidden_states;
apply the same change to the other occurrences listed (lines referencing output,
ref, or other reference tensors at the noted ranges).

---

Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py`:
- Around line 67-86: Add a guard before launching _compact_topk_ids_kernel to
validate the weight_expert_ids buffer is large enough: compute total_pairs =
topk_ids.numel() (already present) and compare weight_expert_ids.numel() against
a conservative bound such as min(total_pairs, state_E) or the expected number of
unique experts from state_E, and raise ValueError if it is smaller; update the
validation block that currently checks compact_topk_ids and active_expert_count
so it also checks weight_expert_ids and includes the same early-return/exception
behavior before calling _compact_topk_ids_kernel.

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 57-75: Remove the now-unused helper function _has_cuda_13 and any
imports or references to it elsewhere in the codebase; locate the definition of
_has_cuda_13 in the test module (and any other modules that import it) and
delete the function and its import lines, run the test suite to ensure nothing
else depends on it, and commit the cleanup.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5ad75d0d-536b-483f-9d9b-c3bd7629024e

📥 Commits

Reviewing files that changed from the base of the PR and between 0e18a1c and 2259f2e.

📒 Files selected for processing (16)
  • benchmarks/routines/moe.py
  • benchmarks/routines/moe_utils.py
  • flashinfer/__init__.py
  • flashinfer/cute_dsl/fp4_common.py
  • flashinfer/fused_moe/__init__.py
  • flashinfer/fused_moe/cute_dsl/__init__.py
  • flashinfer/fused_moe/cute_dsl/b12x_moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • tests/moe/test_b12x_fused_moe.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment on lines +1367 to +1376
# Map ActivationType enum to string for SM120 CuTe DSL API
activation_type = args.activation_type
_ACT_STR = {ActivationType.Swiglu: "silu", ActivationType.Relu2: "relu2"}
if activation_type not in _ACT_STR:
    raise ValueError(
        f"CuTe DSL MoE only supports Swiglu and Relu2 activations, "
        f"got {activation_type.name}"
    )
activation_str = _ACT_STR[activation_type]
is_gated = activation_type == ActivationType.Swiglu
Contributor

⚠️ Potential issue | 🟠 Major

Reject non-gated activations before falling back to the legacy CuteDSL path.

This now allows Relu2 to reach the non-SM120 branch, but that branch still calls CuteDslMoEWrapper / cute_dsl_fused_moe_nvfp4, which are gated-only. Because _create_cute_dsl_moe_test_data() already shrinks w1 to [E, n, k] when is_gated=False, an SM100/SM103 benchmark run will feed non-gated weights into a SwiGLU-only kernel.

Suggested guard
     activation_str = _ACT_STR[activation_type]
     is_gated = activation_type == ActivationType.Swiglu
+    sm_major_bm = torch.cuda.get_device_capability(device)[0]
+    is_sm120 = sm_major_bm == 12
+    if not is_sm120 and not is_gated:
+        raise ValueError(
+            "Non-gated MoE activations are only supported by the b12x SM120/SM121 backend"
+        )
 
     # Create CuteDSL-specific NVFP4 test data
     tensors = _create_cute_dsl_moe_test_data(
         num_tokens=num_tokens,
         hidden_size=hidden_size,
@@
-    sm_major_bm = torch.cuda.get_device_capability(device)[0]
-    is_sm120 = sm_major_bm == 12
     x_input = tensors["x_bf16"] if is_sm120 else tensors["x"]

Also applies to: 1397-1488, 1532-1557

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` around lines 1367 - 1376, The code allows
non-gated ActivationType (Relu2) to fall through into the legacy CuteDSL path
which only supports gated activations; update the guard so that before taking
the non-SM120/legacy branch you validate that activation_type is gated (i.e.,
is_gated True) and raise a ValueError otherwise. Specifically, check
activation_type/is_gated prior to calling CuteDslMoEWrapper or
cute_dsl_fused_moe_nvfp4 (and before using _create_cute_dsl_moe_test_data which
shrinks w1), and if not gated, raise an informative error (e.g., "CuTe DSL MoE
only supports gated activations (SwiGLU)") so Relu2 cannot be passed into the
SwiGLU-only kernels.

Comment thread flashinfer/fused_moe/cute_dsl/b12x_moe.py Outdated
    fc2_input_scale: torch.Tensor,
) -> torch.Tensor:
    """Reference ReLU2 MoE: output = relu(FC1(x))^2, then FC2."""
    output = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device="cuda")
Contributor

⚠️ Potential issue | 🟡 Minor

Keep the reference path on the tensors’ existing device.

device="cuda" and .cuda() both target the current default GPU, not necessarily the GPU these test tensors already live on. On multi-GPU runners that can turn the reference computation into an implicit cross-device copy or fail with device-mismatch errors. Use hidden_states.device / the existing tensor device instead.

Also applies to: 561-563, 635-637, 741-743, 892-894, 952-954, 1017-1019, 1094-1096, 1155-1157, 1210-1212

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_b12x_fused_moe.py` at line 359, The test uses torch.zeros(...,
device="cuda") which forces the default CUDA device and can cause cross-device
copies or failures; replace device="cuda" (and any .cuda() calls) with the
existing tensor's device (e.g., use hidden_states.device or the corresponding
tensor.device) when creating reference tensors like output so they stay on the
same device as hidden_states; apply the same change to the other occurrences
listed (lines referencing output, ref, or other reference tensors at the noted
ranges).
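
A hedged sketch of the fix (make_reference_output is an illustrative name, not a helper in the tests): derive the device from an existing tensor instead of the string "cuda":

```python
import torch

def make_reference_output(hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    # Allocate the reference accumulator on the same device as the
    # inputs. device="cuda" resolves to the *default* CUDA device,
    # which may differ from hidden_states.device on multi-GPU runners
    # and trigger implicit cross-device copies or device-mismatch errors.
    num_tokens = hidden_states.shape[0]
    return torch.zeros(
        num_tokens, hidden_size, dtype=torch.float32, device=hidden_states.device
    )
```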

Andrii Skliar and others added 2 commits April 17, 2026 08:52
Skip the Triton compaction prepass for single-token micro decode by
using the raw topk_ids directly as global expert indices in-kernel.
A single token's top-k routing is already a dense local expert set,
so the id-compaction launch is pure overhead.

Ported from lukealonso/b12x@99c0a80.

~50% latency reduction on bs=1 relu2 decode (313 us -> 157 us).

Co-Authored-By: Luke Alonso <lalonso@gmail.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/moe/test_b12x_fused_moe.py (2)

42-47: is_sm120_family() only inspects device 0.

get_device_properties(0) always looks at GPU 0, which may not match where the test tensors run (and not all CI hosts are guaranteed to have device 0 as the active device). This is only used to choose w1_bf16_prepared in create_moe_tensors, and every test class is already gated by @sm120_required (which queries torch.device("cuda")), so the check is effectively redundant — consider deleting the branch or using torch.cuda.current_device().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_b12x_fused_moe.py` around lines 42 - 47, The helper
is_sm120_family() inspects GPU 0 unconditionally which can mismatch the active
device; update it to use torch.cuda.current_device() (i.e., call
torch.cuda.get_device_properties(torch.cuda.current_device())) or remove the
branch entirely since tests are already gated by `@sm120_required`; locate the
is_sm120_family function and either replace get_device_properties(0) with
get_device_properties(torch.cuda.current_device()) or delete the function/branch
and adjust create_moe_tensors to rely on the existing `@sm120_required` gating for
selecting w1_bf16_prepared.
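
A sketch of the suggested fix, probing the active device instead of device 0 (the non-CUDA guard is an added assumption, not part of the original helper):

```python
import torch

def is_sm120_family() -> bool:
    # Query the *active* device rather than hard-coded device 0, so the
    # probe matches wherever the test tensors actually live.
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return props.major == 12
```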

146-148: Implicit Optional type hints — minor.

Ruff flags RUF013 at line 147. fc2_input_scale, num_local_experts, and (in compute_reference_moe_relu2) other defaultable tensor params use bare types with = None. Prefer explicit Optional[...] / ... | None for clarity:

🔧 Proposed fix
-    fc2_input_scale: torch.Tensor = None,
-    num_local_experts: int = None,
+    fc2_input_scale: Optional[torch.Tensor] = None,
+    num_local_experts: Optional[int] = None,
     local_expert_offset: int = 0,

(requires from typing import Optional)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_b12x_fused_moe.py` around lines 146 - 148, Change bare types
that default to None to explicit optional types: update parameters like
fc2_input_scale: torch.Tensor = None and num_local_experts: int = None to use
Optional[torch.Tensor] and Optional[int] (or torch.Tensor | None and int | None)
in the function signatures (e.g., in compute_reference_moe_relu2 and the
surrounding test helpers); if you choose Optional, add from typing import
Optional at the top of the file. Ensure all other tensor/int params with = None
defaults in the same file are updated consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 42-49: _MICRO_MAC_LADDER currently ends at routed_rows=20 causing
_lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) to return None for 20 <
routed_rows <= 40 (multi-topk micro cutover), so tuned MACs aren't applied even
though the micro path runs; modify the ladder to cover up to 40 (add entries for
routed_rows between 24–40 matching the tuning profile) or change
_lookup_mac_ladder to return the nearest lower rung instead of None, and ensure
callers (where base_mac is used as fallback) apply the returned tuned MACs when
top_k > 1; key symbols to update are _MICRO_MAC_LADDER and _lookup_mac_ladder
and any logic that currently falls back to base_mac for the micro path.

---

Nitpick comments:
In `@tests/moe/test_b12x_fused_moe.py`:
- Around line 42-47: The helper is_sm120_family() inspects GPU 0 unconditionally
which can mismatch the active device; update it to use
torch.cuda.current_device() (i.e., call
torch.cuda.get_device_properties(torch.cuda.current_device())) or remove the
branch entirely since tests are already gated by `@sm120_required`; locate the
is_sm120_family function and either replace get_device_properties(0) with
get_device_properties(torch.cuda.current_device()) or delete the function/branch
and adjust create_moe_tensors to rely on the existing `@sm120_required` gating for
selecting w1_bf16_prepared.
- Around line 146-148: Change bare types that default to None to explicit
optional types: update parameters like fc2_input_scale: torch.Tensor = None and
num_local_experts: int = None to use Optional[torch.Tensor] and Optional[int]
(or torch.Tensor | None and int | None) in the function signatures (e.g., in
compute_reference_moe_relu2 and the surrounding test helpers); if you choose
Optional, add from typing import Optional at the top of the file. Ensure all
other tensor/int params with = None defaults in the same file are updated
consistently.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 838fe31b-ccd7-4a39-9799-72d8b1f33a23

📥 Commits

Reviewing files that changed from the base of the PR and between 2259f2e and afadf59.

📒 Files selected for processing (6)
  • flashinfer/fused_moe/cute_dsl/b12x_moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • tests/moe/test_b12x_fused_moe.py
  • tests/moe/test_cute_dsl_fused_moe.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/fused_moe/cute_dsl/b12x_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/moe/test_cute_dsl_fused_moe.py

Comment on lines +42 to +49
_MICRO_MAC_LADDER: Tuple[Tuple[int, int], ...] = (
    (2, 84),
    (4, 127),
    (8, 107),
    (10, 84),
    (16, 63),
    (20, 84),
)
Contributor

⚠️ Potential issue | 🟡 Minor

_MICRO_MAC_LADDER stops at routed_rows=20 but the multi-topk micro cutover is 40.

When top_k > 1 and 20 < routed_rows <= 40, _lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) returns None and line 870 falls back to the hardware base_mac. The micro path still runs, but the tuned MAC values are not applied for a band of configurations the cutover explicitly enables. Consider extending the ladder past 20 (or documenting that the tail range is intentionally untuned).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
42 - 49, _MICRO_MAC_LADDER currently ends at routed_rows=20 causing
_lookup_mac_ladder(_MICRO_MAC_LADDER, routed_rows) to return None for 20 <
routed_rows <= 40 (multi-topk micro cutover), so tuned MACs aren't applied even
though the micro path runs; modify the ladder to cover up to 40 (add entries for
routed_rows between 24–40 matching the tuning profile) or change
_lookup_mac_ladder to return the nearest lower rung instead of None, and ensure
callers (where base_mac is used as fallback) apply the returned tuned MACs when
top_k > 1; key symbols to update are _MICRO_MAC_LADDER and _lookup_mac_ladder
and any logic that currently falls back to base_mac for the micro path.

- Added a context manager for tagging the current dispatch stage in `moe_dispatch.py`.
- Updated `moe_micro_kernel.py` to use the global expert ID from the compact routing map, improving output accuracy.
- Refactored variable names for clarity and consistency in handling tile views.

This change aims to improve the robustness of the MoE micro-kernel operations and streamline the dispatch context management.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@askliar askliar force-pushed the b12x_micro_kernel_merged branch from afadf59 to 122ac7f on April 17, 2026 at 07:10
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (5)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (4)

850-881: Micro num_tokens == 1 host-fill path: verify weight_expert_ids slice semantics.

On this branch, workspace.weight_expert_ids[: flat_ids.numel()].copy_(flat_ids.to(torch.int32)) assumes flat_ids.numel() == top_k ≤ state_E (workspace allocation size). Today this always holds because top_k ≤ num_experts == num_local_experts on SM120 (EP unsupported per prior learnings), so the slice length equals flat_ids.numel() and copy_ succeeds. If a future refactor ever grows top_k beyond state_E, the slice would silently truncate to state_E and copy_ would raise a size-mismatch error. A torch._assert(flat_ids.numel() <= workspace.weight_expert_ids.numel(), ...) near the existing compact_topk_ids assert would make the invariant explicit. Optional — non-blocking.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
850 - 881, The host-path for use_micro and num_tokens == 1 assumes
workspace.weight_expert_ids has at least flat_ids.numel() capacity which could
silently truncate or later raise; add an explicit assertion like
torch._assert(flat_ids.numel() <= workspace.weight_expert_ids.numel(), ...)
alongside the existing assert for workspace.compact_topk_ids to make the
invariant explicit and fail early; update the block around the use_micro /
num_tokens == 1 branch referencing flat_ids, workspace.compact_topk_ids, and
workspace.weight_expert_ids to include this check before performing
workspace.weight_expert_ids[: flat_ids.numel()].copy_(...).

71-85: sm120_moe_dispatch_context is not thread-safe.

The context manager mutates the module-level _CURRENT_DISPATCH_STAGE without a lock or contextvars.ContextVar. Concurrent use from multiple Python threads (or overlapping async tasks) would interleave stage tags. Given Python's GIL plus the typical single-dispatch-thread usage in inference servers, this is unlikely to bite in practice, but if downstream consumers branch on this value during capture/replay of CUDA graphs from worker threads, switching to contextvars.ContextVar would make the tagging correct by construction.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
71 - 85, The module-level mutable _CURRENT_DISPATCH_STAGE is not thread-safe;
replace it with a contextvars.ContextVar (e.g., CURRENT_DISPATCH_STAGE:
ContextVar[str | None] = ContextVar("CURRENT_DISPATCH_STAGE", default=None)) and
update sm120_moe_dispatch_context to use ContextVar.set() and restore via the
returned token in the finally block (token = CURRENT_DISPATCH_STAGE.set(stage);
finally: CURRENT_DISPATCH_STAGE.reset(token)); update any reads of
_CURRENT_DISPATCH_STAGE to call CURRENT_DISPATCH_STAGE.get() and remove the
global mutation.
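The `ContextVar` pattern proposed above can be sketched as follows; this is a minimal illustration of the set/reset-token discipline, with names mirroring the review text rather than the actual module:

```python
# Hypothetical sketch of ContextVar-based stage tagging; names follow the
# review suggestion, not the real moe_dispatch.py module.
import contextlib
from contextvars import ContextVar
from typing import Iterator, Optional

CURRENT_DISPATCH_STAGE: ContextVar[Optional[str]] = ContextVar(
    "CURRENT_DISPATCH_STAGE", default=None
)


@contextlib.contextmanager
def sm120_moe_dispatch_context(stage: str) -> Iterator[None]:
    # set() returns a token; reset(token) restores the previous value even
    # under nesting or interleaved async tasks, unlike a module-level global.
    token = CURRENT_DISPATCH_STAGE.set(stage)
    try:
        yield
    finally:
        CURRENT_DISPATCH_STAGE.reset(token)
```

Reads elsewhere in the module would then call `CURRENT_DISPATCH_STAGE.get()` instead of consulting the mutable global.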

565-789: _get_micro_kernel duplicates ~95% of _get_static_kernel's fake-tensor plumbing.

The two builders differ only in (a) the kernel class (MoEMicroKernel vs MoEStaticKernel), (b) the share_input_across_experts kwarg, (c) the cache dict and key tag, and (d) always selecting the narrow tiler via _select_moe_mma_tiler_mn for micro. Every change to the compiled argument list or fake-tensor shapes now has to be mirrored in two places. A small helper that builds the fake tensors once and is shared by both builders would eliminate this drift hazard. Non-blocking for this PR since the code is correct, but worth a follow-up.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
565 - 789, The _get_micro_kernel implementation duplicates nearly all
fake-tensor construction from _get_static_kernel; refactor by extracting a
shared helper (e.g., _build_moe_fake_tensors) that returns the tuple/dict of all
fake tensors and common computed values (rows_pad_k, cols_pad_k,
topk_ids_cutlass_dtype, topk_ids_align, etc.), then call that helper from both
_get_micro_kernel and _get_static_kernel and only apply the small differences
(kernel class selection: MoEMicroKernel vs MoEStaticKernel,
share_input_across_experts flag, cache key tag, and the micro-only call to
_select_moe_mma_tiler_mn) before calling cute.compile; ensure the helper returns
values in the same order used in the compiled argument list so you can pass them
unchanged to cute.compile in both functions and update usages of make_ptr and
cute.runtime.make_fake_compact_tensor inside the helper.

160-207: compact_topk_ids buffer sized to max_rows — correct, but fragile if max_rows semantics ever shift.

compact_topk_ids is allocated with max_rows entries, which currently equals routed_rows = num_tokens * top_k at workspace-allocation time (via _get_cached_workspace(..., max_rows=max(1, routed_rows))). The runtime assert at Line 851 (flat_ids.numel() <= workspace.compact_topk_ids.numel()) correctly guards this, and cache reuse growth is monotonic. Consider adding a brief docstring note on this field that its capacity must track routed_rows (not num_tokens) to prevent future regressions when someone resizes workspace dimensions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
160 - 207, The compact_topk_ids buffer in Sm120StaticMoEWorkspace is currently
allocated to max_rows (set from routed_rows) but lacks documentation warning
future maintainers that its capacity must track routed_rows (num_tokens * top_k)
rather than num_tokens; update the allocate_sm120_static_workspace function or
the Sm120StaticMoEWorkspace class docstring to state that compact_topk_ids
capacity must be >= routed_rows and reference the cache allocation path
(_get_cached_workspace) and the runtime check that compares flat_ids.numel() to
workspace.compact_topk_ids.numel() so future changes to max_rows semantics won’t
cause silent regressions.
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py (1)

51-86: Consider also validating weight_expert_ids capacity.

compact_topk_ids and active_expert_count are shape-checked, but weight_expert_ids is indexed at compact_id ∈ [0, num_unique_experts) without an explicit bound check. Callers currently pass workspace.weight_expert_ids sized to state_E, which is an upper bound on unique experts, so today it's safe. Adding a cheap weight_expert_ids.numel() >= min(total_pairs, ... ) guard (or just >= total_pairs as a conservative bound) would make the contract robust against future callers. Optional — feel free to leave as-is if the docstring is considered sufficient.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines
51 - 86, The function compact_topk_ids does not validate that weight_expert_ids
has enough capacity for the kernel to write the dense local->global map; add a
cheap runtime check at the top of compact_topk_ids that ensures
weight_expert_ids.numel() is large enough (e.g., weight_expert_ids.numel() >=
total_pairs or another conservative upper bound such as the expected state_E)
and raise a ValueError with a clear message if not, so callers supplying
workspace.weight_expert_ids with insufficient size fail early.
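The cheap guard described above might look like the following sketch; the parameter names are stand-ins for the tensors named in the review, taking element counts rather than tensors so the shape of the check is clear:

```python
# Illustrative capacity guard only; in the real code the first argument
# would be weight_expert_ids.numel() and the second the pair count.
def check_weight_expert_ids_capacity(capacity: int, total_pairs: int) -> None:
    if capacity < total_pairs:
        raise ValueError(
            f"weight_expert_ids has {capacity} slots but up to "
            f"{total_pairs} (token, expert) pairs may be written; "
            "allocate at least total_pairs entries."
        )
```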
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 1689-1694: Add an explicit validation for the activation string to
avoid silently treating unknown activations as non‑gated: in the block around
is_gated = activation == "silu" (used by b12x_moe), assert or raise if
activation is not one of the supported values (e.g. "silu" or "relu2"); then
compute intermediate_size based on is_gated as before. Reference the symbols
activation, is_gated, intermediate_size, w1_weight and the public caller
b12x_moe so the check is colocated with the current logic.

---

Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 850-881: The host-path for use_micro and num_tokens == 1 assumes
workspace.weight_expert_ids has at least flat_ids.numel() capacity which could
silently truncate or later raise; add an explicit assertion like
torch._assert(flat_ids.numel() <= workspace.weight_expert_ids.numel(), ...)
alongside the existing assert for workspace.compact_topk_ids to make the
invariant explicit and fail early; update the block around the use_micro /
num_tokens == 1 branch referencing flat_ids, workspace.compact_topk_ids, and
workspace.weight_expert_ids to include this check before performing
workspace.weight_expert_ids[: flat_ids.numel()].copy_(...).
- Around line 71-85: The module-level mutable _CURRENT_DISPATCH_STAGE is not
thread-safe; replace it with a contextvars.ContextVar (e.g.,
CURRENT_DISPATCH_STAGE: ContextVar[str | None] =
ContextVar("CURRENT_DISPATCH_STAGE", default=None)) and update
sm120_moe_dispatch_context to use ContextVar.set() and restore via the returned
token in the finally block (token = CURRENT_DISPATCH_STAGE.set(stage); finally:
CURRENT_DISPATCH_STAGE.reset(token)); update any reads of
_CURRENT_DISPATCH_STAGE to call CURRENT_DISPATCH_STAGE.get() and remove the
global mutation.
- Around line 565-789: The _get_micro_kernel implementation duplicates nearly
all fake-tensor construction from _get_static_kernel; refactor by extracting a
shared helper (e.g., _build_moe_fake_tensors) that returns the tuple/dict of all
fake tensors and common computed values (rows_pad_k, cols_pad_k,
topk_ids_cutlass_dtype, topk_ids_align, etc.), then call that helper from both
_get_micro_kernel and _get_static_kernel and only apply the small differences
(kernel class selection: MoEMicroKernel vs MoEStaticKernel,
share_input_across_experts flag, cache key tag, and the micro-only call to
_select_moe_mma_tiler_mn) before calling cute.compile; ensure the helper returns
values in the same order used in the compiled argument list so you can pass them
unchanged to cute.compile in both functions and update usages of make_ptr and
cute.runtime.make_fake_compact_tensor inside the helper.
- Around line 160-207: The compact_topk_ids buffer in Sm120StaticMoEWorkspace is
currently allocated to max_rows (set from routed_rows) but lacks documentation
warning future maintainers that its capacity must track routed_rows (num_tokens
* top_k) rather than num_tokens; update the allocate_sm120_static_workspace
function or the Sm120StaticMoEWorkspace class docstring to state that
compact_topk_ids capacity must be >= routed_rows and reference the cache
allocation path (_get_cached_workspace) and the runtime check that compares
flat_ids.numel() to workspace.compact_topk_ids.numel() so future changes to
max_rows semantics won’t cause silent regressions.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py`:
- Around line 51-86: The function compact_topk_ids does not validate that
weight_expert_ids has enough capacity for the kernel to write the dense
local->global map; add a cheap runtime check at the top of compact_topk_ids that
ensures weight_expert_ids.numel() is large enough (e.g.,
weight_expert_ids.numel() >= total_pairs or another conservative upper bound
such as the expected state_E) and raise a ValueError with a clear message if
not, so callers supplying workspace.weight_expert_ids with insufficient size
fail early.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3ef15b7c-633d-48c5-978d-73971bc16e78

📥 Commits

Reviewing files that changed from the base of the PR and between afadf59 and 122ac7f.

📒 Files selected for processing (5)
  • flashinfer/cute_dsl/utils.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py

Comment on lines 1689 to 1694
num_tokens = topk_ids.size(0)
k = a.size(1) # hidden_size
intermediate_size = w1_weight.size(1) // 2
is_gated = activation == "silu"
# w1_weight.size(1) is 2*n for gated or n for non-gated
intermediate_size = w1_weight.size(1) // 2 if is_gated else w1_weight.size(1)
n = intermediate_size

⚠️ Potential issue | 🟡 Minor

Silent fallthrough for unknown activation values.

is_gated = activation == "silu" silently treats any value that is not exactly "silu" (including typos or future additions) as non-gated, so intermediate_size = w1_weight.size(1) will mis-derive n for unrecognized strings. The single public caller (b12x_moe) documents only "silu" and "relu2", so nothing is broken today, but a cheap validation (assert activation in ("silu", "relu2") here or in b12x_moe) would turn future misuse into an immediate error instead of a scale-factor mismatch at launch time.

🛡️ Optional guard
     num_tokens = topk_ids.size(0)
     k = a.size(1)  # hidden_size
+    if activation not in ("silu", "relu2"):
+        raise ValueError(
+            f"Unsupported activation {activation!r}; expected 'silu' or 'relu2'."
+        )
     is_gated = activation == "silu"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
num_tokens = topk_ids.size(0)
k = a.size(1) # hidden_size
intermediate_size = w1_weight.size(1) // 2
is_gated = activation == "silu"
# w1_weight.size(1) is 2*n for gated or n for non-gated
intermediate_size = w1_weight.size(1) // 2 if is_gated else w1_weight.size(1)
n = intermediate_size
num_tokens = topk_ids.size(0)
k = a.size(1) # hidden_size
if activation not in ("silu", "relu2"):
raise ValueError(
f"Unsupported activation {activation!r}; expected 'silu' or 'relu2'."
)
is_gated = activation == "silu"
# w1_weight.size(1) is 2*n for gated or n for non-gated
intermediate_size = w1_weight.size(1) // 2 if is_gated else w1_weight.size(1)
n = intermediate_size
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
1689 - 1694, Add an explicit validation for the activation string to avoid
silently treating unknown activations as non‑gated: in the block around is_gated
= activation == "silu" (used by b12x_moe), assert or raise if activation is not
one of the supported values (e.g. "silu" or "relu2"); then compute
intermediate_size based on is_gated as before. Reference the symbols activation,
is_gated, intermediate_size, w1_weight and the public caller b12x_moe so the
check is colocated with the current logic.

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
bkryu and others added 5 commits April 17, 2026 11:26
…es, Spark MAC cap

From lukealonso/b12x commits a103209d and 5043ca40 (AI-assisted):

moe_micro_kernel.py:
- Add single_token and share_expert_scales compile-time flags to
  _MoEMicroKernelBase. Replaces all runtime all_rows_unique variable
  checks with cutlass.const_expr(self.single_token) to eliminate
  dead code paths at JIT compile time.
- Replace shared_single_input runtime int with
  cutlass.const_expr(self.share_input_across_experts).
- Add scale_idx conditional in both quant phases: when share_expert_scales
  is True, use index 0 instead of per-expert lookup.

moe_dispatch.py:
- Add _get_relu2_bs1_spark_micro_cap() returning 42 (non-monotonic optimum
  on DGX Spark measured in b12x profiling).
- Cap micro MAC at 42 for relu2 bs=1 on DGX Spark (sm_count <= 96).
- Pass share_expert_scales=(both input scales are scalar) and
  single_token=(num_tokens==1) to _get_micro_kernel.

dense_blockscaled_gemm_sm120.py:
- Add single_work_tile_per_cta flag: skips tile_sched.advance_to_next_work()
  (global atomic) when each CTA handles exactly one tile, reducing
  overhead on small-M wide-N decode shapes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/utils.py (1)

117-147: ⚠️ Potential issue | 🟡 Minor

Avoid caching transient fallback results and narrow the caught exception.

Line 129 catches every Exception, and because this function is cached with @functools.cache, the first transient probe failure permanently caches sm_count for that cluster_size. Subsequent calls to the same function with the same cluster_size will never retry the probe, even if the CUDA driver context becomes available later. Additionally, the broad exception catch risks hiding unexpected errors unrelated to the CUDA driver context issue. Split successful probe caching into a helper so fallback calls can retry, and narrow the exception to RuntimeError, which is what CUDA driver API calls raise.

Suggested refactor
+@functools.cache
+def _get_max_active_clusters_from_hardware(cluster_size: int) -> int:
+    return get_hardware_info().get_max_active_clusters(cluster_size)
+
+
-@functools.cache
 def get_max_active_clusters(cluster_size: int) -> int:
     """Get max active clusters for a given cluster size (cached).
     
     Args:
         cluster_size: Product of cluster_shape_mn dimensions.
     
     Returns:
         Maximum number of active clusters supported by hardware.
     """
     try:
-        return get_hardware_info().get_max_active_clusters(cluster_size)
-    except Exception as exc:
+        return _get_max_active_clusters_from_hardware(cluster_size)
+    except RuntimeError as exc:
         # nvidia_cutlass_dsl's hardware probe (cuKernelGetFunction) can fail
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/utils.py` around lines 117 - 147, The cached
get_max_active_clusters currently catches all Exception and caches the fallback
value, so transient probe failures get permanently stored; refactor by moving
the actual probe into a non-cached helper (e.g.,
_probe_max_active_clusters(cluster_size) that calls
get_hardware_info().get_max_active_clusters(cluster_size)), make
get_max_active_clusters (still decorated with functools.cache) call the helper
and only cache successful probe results, and narrow the except clause to catch
RuntimeError (the expected CUDA driver failure) so fallback to
get_num_sm(torch.device("cuda")) is used without being cached; ensure the helper
raises on unexpected exceptions so they surface.
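The key property of the refactor above is that `functools.cache` only memoizes successful probes, so a transient failure falls back without being pinned. A minimal sketch, with `_probe` standing in for `get_hardware_info().get_max_active_clusters`:

```python
# Sketch of "cache only successful probes"; _probe is a made-up stand-in
# for the real hardware query, which can raise RuntimeError transiently.
import functools


def _probe(cluster_size: int) -> int:
    if cluster_size <= 0:
        raise RuntimeError("CUDA driver context not available")
    return 148 // cluster_size  # placeholder arithmetic, not real SM math


@functools.cache
def _probe_cached(cluster_size: int) -> int:
    # functools.cache does not memoize raised exceptions, so a failed
    # probe here is retried on the next call.
    return _probe(cluster_size)


def get_max_active_clusters(cluster_size: int, fallback: int = 1) -> int:
    try:
        return _probe_cached(cluster_size)  # successes are cached
    except RuntimeError:
        # fallback is returned WITHOUT being cached, so later calls retry
        return fallback
```

Because `functools.cache` never stores a call that raised, moving the `try`/`except` outside the cached helper is what makes the fallback transient.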
♻️ Duplicate comments (1)
tests/moe/test_b12x_fused_moe.py (1)

385-385: ⚠️ Potential issue | 🟡 Minor

Avoid default-device CUDA placement in reference tensors.

device="cuda" and .cuda() target the current default GPU, which can mismatch the tensors’ actual device on multi-GPU runners. Keep reference tensors on the source tensor device.

Suggested fix
-    output = torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device="cuda")
+    output = torch.zeros(
+        num_tokens,
+        hidden_size,
+        dtype=torch.float32,
+        device=hidden_states.device,
+    )
-        ref_output = compute_reference_moe_fp4(
-            hidden_states=tensors["x_bf16"].float().cuda(),
-            gemm1_weights=tensors["w1_weight_bf16"].float().cuda(),
-            gemm2_weights=tensors["w2_weight_bf16"].float().cuda(),
+        ref_output = compute_reference_moe_fp4(
+            hidden_states=tensors["x_bf16"].float().to(tensors["x_bf16"].device),
+            gemm1_weights=tensors["w1_weight_bf16"].float().to(tensors["x_bf16"].device),
+            gemm2_weights=tensors["w2_weight_bf16"].float().to(tensors["x_bf16"].device),
             ...
         )
-        ref_output = compute_reference_moe_relu2(
-            hidden_states=tensors["x_bf16"].float().cuda(),
-            fc1_weights=tensors["w1_weight_bf16"].float().cuda(),
-            fc2_weights=tensors["w2_weight_bf16"].float().cuda(),
+        ref_output = compute_reference_moe_relu2(
+            hidden_states=tensors["x_bf16"].float().to(tensors["x_bf16"].device),
+            fc1_weights=tensors["w1_weight_bf16"].float().to(tensors["x_bf16"].device),
+            fc2_weights=tensors["w2_weight_bf16"].float().to(tensors["x_bf16"].device),
             ...
         )

Also applies to: 589-591, 665-667, 771-773, 926-928, 986-988, 1051-1053, 1130-1132, 1191-1193, 1246-1248

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_b12x_fused_moe.py` at line 385, Reference tensors (e.g., the
`output` tensor created as torch.zeros(num_tokens, hidden_size,
dtype=torch.float32, device="cuda") in the test function) are being placed on
the default CUDA device; change these to be allocated on the source/input
tensor's device instead (use the device of the input tensor, e.g., input.device
or tokens.device) and remove any `.cuda()` calls so the reference tensors follow
the actual tensor under test; apply the same fix for the other similar reference
allocations noted (the other `torch.zeros`/`.cuda()` usages around the file).
🧹 Nitpick comments (2)
tests/moe/test_b12x_fused_moe.py (2)

158-160: Use explicit Optional typing for nullable tensor parameters.

fc2_input_scale: torch.Tensor = None and num_local_experts: int = None should be explicitly annotated as nullable for type-checker clarity.

Suggested fix
+from typing import Optional
...
-    fc2_input_scale: torch.Tensor = None,
-    num_local_experts: int = None,
+    fc2_input_scale: Optional[torch.Tensor] = None,
+    num_local_experts: Optional[int] = None,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_b12x_fused_moe.py` around lines 158 - 160, Update the function
signature to use explicit nullable types: change the annotations for
fc2_input_scale and num_local_experts to Optional[torch.Tensor] and
Optional[int] respectively (keeping their default = None), and add the
corresponding import from typing (Optional) at the top of the file; locate the
signature that currently lists fc2_input_scale: torch.Tensor = None,
num_local_experts: int = None (within tests/moe/test_b12x_fused_moe.py) and
replace those types so static type checkers correctly recognize the nullable
parameters.

64-64: Narrow the broad exception in CUDA-version gating.

except Exception can mask unrelated failures and silently skip tests. Catch expected error types (import/runtime for CUDA query) and let unexpected ones surface.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_b12x_fused_moe.py` at line 64, The test's broad "except
Exception" in the CUDA-version gating should be narrowed to only catch expected
import/runtime errors: replace "except Exception" with "except (ImportError,
RuntimeError, OSError)" (or the specific exceptions raised by your CUDA/version
check) and ensure any other exceptions are allowed to propagate (either remove
the broad handler or re-raise unexpected exceptions). Update the except block in
tests/moe/test_b12x_fused_moe.py around the CUDA gating logic so only those
specific exception types are swallowed and all other failures surface.
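The narrowed gating might be shaped like this sketch; `_cuda_version` is a made-up probe standing in for whatever version check the test performs:

```python
# Hypothetical narrowed CUDA gating; _cuda_version stands in for the
# test's real probe and here always fails, as on a driver-less machine.
def _cuda_version() -> tuple:
    raise RuntimeError("CUDA driver not initialized")


def cuda_at_least(major: int, minor: int) -> bool:
    try:
        ver = _cuda_version()
    except (ImportError, RuntimeError, OSError):
        # expected probe failures -> gate closed, tests get skipped
        return False
    # any other exception propagates and surfaces as a real failure
    return ver >= (major, minor)
```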
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/cute_dsl/utils.py`:
- Around line 117-147: The cached get_max_active_clusters currently catches all
Exception and caches the fallback value, so transient probe failures get
permanently stored; refactor by moving the actual probe into a non-cached helper
(e.g., _probe_max_active_clusters(cluster_size) that calls
get_hardware_info().get_max_active_clusters(cluster_size)), make
get_max_active_clusters (still decorated with functools.cache) call the helper
and only cache successful probe results, and narrow the except clause to catch
RuntimeError (the expected CUDA driver failure) so fallback to
get_num_sm(torch.device("cuda")) is used without being cached; ensure the helper
raises on unexpected exceptions so they surface.

---

Duplicate comments:
In `@tests/moe/test_b12x_fused_moe.py`:
- Line 385: Reference tensors (e.g., the `output` tensor created as
torch.zeros(num_tokens, hidden_size, dtype=torch.float32, device="cuda") in the
test function) are being placed on the default CUDA device; change these to be
allocated on the source/input tensor's device instead (use the device of the
input tensor, e.g., input.device or tokens.device) and remove any `.cuda()`
calls so the reference tensors follow the actual tensor under test; apply the
same fix for the other similar reference allocations noted (the other
`torch.zeros`/`.cuda()` usages around the file).

---

Nitpick comments:
In `@tests/moe/test_b12x_fused_moe.py`:
- Around line 158-160: Update the function signature to use explicit nullable
types: change the annotations for fc2_input_scale and num_local_experts to
Optional[torch.Tensor] and Optional[int] respectively (keeping their default =
None), and add the corresponding import from typing (Optional) at the top of the
file; locate the signature that currently lists fc2_input_scale: torch.Tensor =
None, num_local_experts: int = None (within tests/moe/test_b12x_fused_moe.py)
and replace those types so static type checkers correctly recognize the nullable
parameters.
- Line 64: The test's broad "except Exception" in the CUDA-version gating should
be narrowed to only catch expected import/runtime errors: replace "except
Exception" with "except (ImportError, RuntimeError, OSError)" (or the specific
exceptions raised by your CUDA/version check) and ensure any other exceptions
are allowed to propagate (either remove the broad handler or re-raise unexpected
exceptions). Update the except block in tests/moe/test_b12x_fused_moe.py around
the CUDA gating logic so only those specific exception types are swallowed and
all other failures surface.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 288d3502-cad6-42ce-a34f-8b8b20e15397

📥 Commits

Reviewing files that changed from the base of the PR and between aa649ed and 7310ade.

📒 Files selected for processing (2)
  • flashinfer/cute_dsl/utils.py
  • tests/moe/test_b12x_fused_moe.py
