
chore: cute dsl nvfp4 moe clean up #2775

Merged
aleozlx merged 11 commits into flashinfer-ai:main from nv-yunzheq:cute_dsl_moe_clean_up
Mar 19, 2026

Conversation


@nv-yunzheq (Collaborator) commented Mar 12, 2026

📌 Description

This PR cleans up the CuteDSL NVFP4 MoE implementation:

  1. Remove the incorrect statement in the code that SM110 is supported
  2. Add the CuteDSL MoE to the benchmark script
  3. Adjust the autotuning strategy and tactics, and use a unit test to exercise all tactics
  4. Remove the unused block-scaled FP4 grouped GEMM from the MoE path
  5. Make PDL a function parameter instead of an environment variable, aligning with the rest of the library
  6. Add a compute-capability (CC) support decorator to the CuteDSL MoE function

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added CuteDSL FP4 block-scaled MoE benchmark and end-to-end test path; PDL (Programmatic Dependent Launch) is now runtime-configurable.
  • Bug Fixes

    • Disabled the problematic tile-size=256 tactic.
    • Clarified supported Blackwell architectures to SM100 family (SM100/SM103).
  • Refactor

    • Simplified MoE tactic selection to runtime validation with deterministic tactic shapes.
    • Replaced environment-based PDL flag with a runtime parameter.
  • Tests

    • Added comprehensive MoE tactic validation tests and benchmark samples (some entries duplicated).


coderabbitai bot commented Mar 12, 2026

📝 Walkthrough

Walkthrough

Threads a per-instance enable_pdl flag through CuteDSL MoE runners and kernels, removes several legacy Blackwell kernel modules and a global TRTLLM_ENABLE_PDL flag, adds a new cute_dsl_fp4_block_scale_moe benchmark/test path, and updates tuners, tests, and benchmark metadata.

Changes

Cohort / File(s) Summary
Benchmarks
benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/routines/moe.py, benchmarks/bench_moe_deepseek.py, benchmarks/samples/sample_testlist.txt
Register new cute_dsl_fp4_block_scale_moe benchmark; add test dispatch and helpers for CuteDSL NVFP4 MoE; update SM10x docstrings/text; duplicated sample_testlist entries added.
Core CuteDSL MoE & Runner
flashinfer/fused_moe/cute_dsl/fused_moe.py, flashinfer/fused_moe/cute_dsl/tuner.py
Add enable_pdl: bool = True across core impls, wrapper API, and runner init/forward; tighten tactic generation/validation and thread flag into forward and kernel calls; adjust prealloc logic for CUDA graphs; add supported_compute_capability([100,103]) annotations.
PDL plumbing / wrappers
flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py, flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py, flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
Add enable_pdl param to _get_compiled_* cache keys and public wrappers; propagate to kernel construction/invocation; remove flashinfer_api decorators; update SM100-family error messages.
Blackwell kernel removal
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py, flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm.py, flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
Delete large Blackwell/SM100 kernel modules and high-level NVFP4 grouped-GEMM / SwiGLU wrappers and related utilities (many public symbols removed).
Blackwell init & utils cleanup
flashinfer/fused_moe/cute_dsl/blackwell/__init__.py, flashinfer/fused_moe/cute_dsl/blackwell/utils.py
Remove re-exports of removed kernels and drop module-level TRTLLM_ENABLE_PDL env flag.
Blackwell kernel class update
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
Add enable_pdl constructor arg, store self.enable_pdl, and use it for kernel launch use_pdl instead of global flag.
JIT / nvcc flags
flashinfer/jit/moe_utils.py
Restrict supported NVCC major versions to [10] for JIT compilation context.
Tests
tests/moe/test_cute_dsl_fused_moe.py
Add an “all valid tactics” test that iterates tactics vs. reference (duplicated test class present); update is_sm100_family docstring wording to SM10x.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Wrapper as CuteDslMoEWrapper
    participant Tuner as CuteDslFusedMoENvfp4Runner
    participant GEMM1 as GEMM1 Kernel
    participant GEMM2 as GEMM2 Finalize Kernel

    User->>Wrapper: cute_dsl_fused_moe_nvfp4(..., enable_pdl=True)
    Wrapper->>Tuner: init(..., enable_pdl=True)
    Tuner->>Tuner: store self.enable_pdl
    Wrapper->>Tuner: forward(inputs, enable_pdl=self.enable_pdl)
    Tuner->>Tuner: validate/get_valid_tactics()
    Tuner->>GEMM1: launch(gemm1_args..., enable_pdl=self.enable_pdl)
    GEMM1-->>Tuner: intermediate output
    Tuner->>GEMM2: launch(finalize_args..., enable_pdl=self.enable_pdl)
    GEMM2-->>Tuner: final output
    Tuner-->>Wrapper: result
    Wrapper-->>User: MoE output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested labels

cute-dsl

Suggested reviewers

  • yzh119
  • cyx-6
  • jiahanc
  • jimmyzho
  • yongwww
  • Anerudhan
  • bkryu
  • kahyunnam

Poem

🐇 I hop through kernels, tiny and spry,
enable_pdl now set per try,
Old modules rest beneath the hill,
CuteDSL hums — FP4 thrill,
Benchmarks cheer — a carrot pie!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check: ✅ Passed. The PR title 'chore: cute dsl nvfp4 moe clean up' clearly summarizes the main objective of cleanup and refactoring work on the CuteDSL NVFP4 MoE implementation.
  • Description check: ✅ Passed. The PR description comprehensively covers all major changes: SM110 removal, benchmark additions, autotune strategy adjustments, removal of unused code, PDL parameter changes, and CC support decorator additions.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 87.50%, which is sufficient; the required threshold is 80.00%.



@nv-yunzheq
Collaborator Author

/bot run

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on refining and improving the CuteDSL NVFP4 Mixture-of-Experts (MoE) implementation, primarily targeting the Blackwell architecture. The changes include correcting supported compute capabilities, integrating the MoE kernels into the benchmarking system, and enhancing the autotuning mechanism with more robust tactic validation and explicit parameterization for features like Programmatic Dependent Launch (PDL). These updates aim to improve the correctness, testability, and configurability of the CuteDSL MoE kernels.

Highlights

  • Removed SM110 Support Claim: Corrected documentation and code to accurately state that CuteDSL MoE NVFP4 kernels are optimized for SM100 and SM103, removing the incorrect mention of SM110.
  • Integrated CuteDSL MoE Benchmarks: Added the "cute_dsl_fp4_block_scale_moe" routine to the benchmark script and its compute capability mapping, along with sample test cases.
  • Refined Autotuning Strategy: Updated the autotuning logic for CuteDSL MoE, including a fix for tile_size=256 in GEMM1, and implemented can_implement checks to filter valid tactics based on problem dimensions.
  • Removed Unused Kernels: Eliminated the blockscaled_contiguous_grouped_gemm and blockscaled_contiguous_grouped_gemm_swiglu_fusion kernels and their related imports, streamlining the codebase.
  • Parameterized PDL Enablement: Converted the TRTLLM_ENABLE_PDL environment variable into an explicit enable_pdl function parameter across relevant CuteDSL MoE kernels and wrappers, improving configurability.
  • Added Compute Capability Decorator: Applied a supported_compute_capability decorator to the CuteDSL MoE functions and wrapper, ensuring they are only used on compatible GPU architectures (SM100, SM103).
  • Comprehensive Tactic Testing: Introduced a new unit test class TestAllValidTactics to verify the numerical accuracy of all valid autotuning tactics for CuteDSL MoE.
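The compute-capability decorator described above can be approximated with a small sketch; this mirrors the idea of flashinfer's supported_compute_capability([100, 103]) but the implementation below is hypothetical, including the get_sm hook used to make it testable:

```python
# Hypothetical sketch of a compute-capability gating decorator. The real
# flashinfer decorator differs; get_sm is an injected stand-in for querying
# the device's SM version (e.g. 100 for SM100, 103 for SM103).
from functools import wraps


def supported_compute_capability(allowed_sms, get_sm=lambda: 100):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            sm = get_sm()
            if sm not in allowed_sms:
                raise RuntimeError(
                    f"{fn.__name__} requires SM in {allowed_sms}, got SM{sm}"
                )
            return fn(*args, **kwargs)
        return wrapper
    return decorator


@supported_compute_capability([100, 103])
def fused_moe_nvfp4():
    return "ok"
```

Gating at the wrapper boundary turns a late, cryptic kernel-launch failure on unsupported GPUs into an immediate, descriptive error.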


Changelog
  • benchmarks/bench_moe_deepseek.py
    • Updated the is_sm100_family function's docstring to remove SM110 from the list of supported architectures.
  • benchmarks/routines/flashinfer_benchmark_utils.py
    • Added cute_dsl_fp4_block_scale_moe to the list of moe_routines and defined its supported compute capabilities.
  • benchmarks/routines/moe.py
    • Implemented testCuteDslFp4BlockScaleMoe for benchmarking, including helper functions _interleave_linear_and_gate and _create_cute_dsl_moe_test_data for setting up test data.
  • benchmarks/samples/sample_testlist.txt
    • Added several benchmark commands for cute_dsl_fp4_block_scale_moe with various configurations, including autotuning and expert parallelism.
  • flashinfer/fused_moe/cute_dsl/blackwell/__init__.py
    • Removed imports for Sm100BlockScaledContiguousGroupedGemmKernel, Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel, and TRTLLM_ENABLE_PDL.
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
    • Removed TRTLLM_ENABLE_PDL import, added enable_pdl as an __init__ parameter and used it in kernel launch, and updated the ValueError message for compute capability check.
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py
    • Removed the entire file, as the kernel is no longer used.
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
    • Removed TRTLLM_ENABLE_PDL import, added enable_pdl as an __init__ parameter and used it in kernel launch.
  • flashinfer/fused_moe/cute_dsl/blackwell/utils.py
    • Removed TRTLLM_ENABLE_PDL variable and the os module import.
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
    • Removed flashinfer_api decorator, added enable_pdl as a parameter to _get_compiled_gather_kernel and blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4, and updated the ValueError message for compute capability check.
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
    • Removed flashinfer_api decorator, added enable_pdl as a parameter to _get_compiled_finalize_kernel and blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4, and updated the ValueError message for compute capability check.
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
    • Imported supported_compute_capability, added enable_pdl parameter to _moe_core_impl, CuteDslMoEWrapper.__init__, CuteDslMoEWrapper._forward_with_tactic, and cute_dsl_fused_moe_nvfp4. Also added supported_compute_capability decorator to CuteDslMoEWrapper.__init__ and cute_dsl_fused_moe_nvfp4, and adjusted the use_prealloc logic in _forward_with_tactic.
  • flashinfer/fused_moe/cute_dsl/tuner.py
    • Updated comments and logic in get_gemm2_valid_tactics to clarify MMA tiler and cluster shape for the finalize kernel, restricted tile_size to 128 for GEMM1 tactics due to a bug, updated the total tactics count, and added can_implement checks within get_valid_tactics to filter tactics based on problem dimensions.
  • flashinfer/jit/moe_utils.py
    • Modified gen_moe_utils_module to restrict nvcc_flags to supported_major_versions=[10].
  • tests/moe/test_cute_dsl_fused_moe.py
    • Updated the is_sm100_family docstring to remove SM110, and added a new test class TestAllValidTactics to verify the accuracy of all valid autotuning tactics.
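The moe_utils.py change restricts JIT compilation to NVCC major version 10; the version-gating logic can be sketched like this (the helper name and inputs are illustrative, not the actual gen_moe_utils_module internals):

```python
def filter_nvcc_versions(available_versions, supported_major_versions=(10,)):
    """Keep only toolchain versions whose major component is supported,
    mirroring the supported_major_versions=[10] restriction described above.
    Versions are given as 'major.minor' strings."""
    return [
        v for v in available_versions
        if int(str(v).split(".")[0]) in supported_major_versions
    ]
```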

@flashinfer-bot
Collaborator

GitLab MR !411 has been created, and the CI pipeline #46012789 is currently running. I'll report back once the pipeline job completes.

@gemini-code-assist bot left a comment (Contributor)

Code Review

This pull request provides a comprehensive cleanup and enhancement for the CuteDSL NVFP4 MoE implementation. The changes are well-structured and improve the codebase in several key areas. Notably, the refactoring to remove environment variable dependencies in favor of function parameters makes the kernels more modular and easier to use. The autotuning strategy has been significantly improved by dynamically filtering valid tactics based on problem size, which increases robustness. The removal of unused code and the addition of thorough unit tests for all valid tactics are excellent for maintainability and correctness. The new benchmarks and compute capability decorators are also valuable additions. Overall, this is a high-quality contribution that significantly refines the feature.

@coderabbitai bot left a comment (Contributor)

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
flashinfer/fused_moe/cute_dsl/tuner.py (2)

299-308: ⚠️ Potential issue | 🟠 Major

Validate GEMM2 with the configured output dtype, not hard-coded BF16.

CuteDslFusedMoENvfp4Runner exposes output_dtype, but get_valid_tactics() always asks the finalize kernel about cutlass.BFloat16. On a non-BF16 runner this can admit or reject tactics against the wrong alignment rules. Either map self.output_dtype here or fail fast when only BF16 is supported.

Also applies to: 361-367, 397-403

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/tuner.py` around lines 299 - 308,
get_valid_tactics() currently validates GEMM2 using a hard-coded
cutlass.BFloat16 type which mis-validates tactics when
CuteDslFusedMoENvfp4Runner was constructed with a different output_dtype; update
get_valid_tactics() to map/convert self.output_dtype to the correct cutlass
dtype (or raise a clear error if only BF16 is supported) before querying the
finalize kernel and GEMM2 validation so the alignment/packing rules match the
runner’s configured output_dtype (references: CuteDslFusedMoENvfp4Runner,
get_valid_tactics, self.output_dtype, GEMM2, finalize kernel).

299-308: ⚠️ Potential issue | 🟡 Minor

Include enable_pdl in the runner hash.

enable_pdl changes the execution path and the downstream JIT cache key, but __hash__() still ignores it. That lets autotune results learned with PDL on leak into a runner created with PDL off, and vice versa.

Suggested fix
     def __hash__(self):
         return hash(
             (
                 self.num_experts,
                 self.top_k,
                 self.num_local_experts,
                 self.local_expert_offset,
                 self.use_fused_finalize,
                 self.output_dtype,
+                self.enable_pdl,
             )
         )

Also applies to: 310-320

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/tuner.py` around lines 299 - 308, The runner's
__hash__ currently omits enable_pdl so autotune/JIT cache keys mix PDL and
non-PDL runs; update the class's __hash__ implementation to include
self.enable_pdl alongside the other fields used (forward_impl, num_experts,
top_k, num_local_experts, local_expert_offset, use_fused_finalize, output_dtype)
so the hash reflects the PDL flag; do the same change for the other similar
class/section referenced (the block around where enable_pdl is set at lines
~310-320) so both runner hashes incorporate enable_pdl.
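The cache-key hazard described in this comment is easy to demonstrate with a toy model of the runner; ToyRunner below is a hypothetical stand-in, not the real class in tuner.py:

```python
class ToyRunner:
    """Toy stand-in for the autotune runner; only the fields relevant
    to hashing are modeled."""

    def __init__(self, num_experts, top_k, enable_pdl, include_pdl_in_hash):
        self.num_experts = num_experts
        self.top_k = top_k
        self.enable_pdl = enable_pdl
        self._include = include_pdl_in_hash

    def __hash__(self):
        # Omitting enable_pdl makes PDL-on and PDL-off runners
        # indistinguishable as cache keys; including it separates them.
        fields = (self.num_experts, self.top_k)
        if self._include:
            fields += (self.enable_pdl,)
        return hash(fields)


a = ToyRunner(8, 2, True, include_pdl_in_hash=False)
b = ToyRunner(8, 2, False, include_pdl_in_hash=False)
```

With the flag excluded, a and b hash identically, so tactics tuned under one PDL setting would be served to the other.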
flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

162-191: ⚠️ Potential issue | 🟠 Major

The finalize-kernel cache key is still too weak for cute.compile().

This helper now separates PDL/non-PDL kernels, but it still ignores the typed-pointer parts of the compiled signature. Reusing a kernel first compiled for one combination of ab_dtype/sf_dtype/out_dtype or one token_final_scales.dtype with a different combination will dispatch the wrong wrapper specialization.

Also applies to: 203-211

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 162 - 191, The cache key for _get_compiled_finalize_kernel must
include the element dtypes of the typed pointers and token scales so compiled
wrappers aren’t reused across incompatible pointer/scale types; update the key
construction to incorporate the element types for
a_ptr/b_ptr/c_ptr/a_sf_ptr/b_sf_ptr/alpha_ptr and the dtype of token_scales_ptr
(and keep the existing enable_pdl flag), and apply the same change to the
companion key-building site that creates the finalize-kernel key for the other
code path so both PDL and non-PDL kernels are keyed by pointer element types and
token_scales dtype.
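A strengthened cache key of the kind this prompt asks for might look like the following sketch; get_compiled_finalize_kernel and its dtype arguments are illustrative stand-ins for the real helper, and compile_kernel stands in for cute.compile():

```python
# Hypothetical sketch: memoize the compile step on every input that changes
# the compiled signature (pointer element dtypes, token-scale dtype) plus
# the enable_pdl flag, so incompatible specializations are never reused.
_kernel_cache = {}


def get_compiled_finalize_kernel(
    ab_dtype, sf_dtype, out_dtype, token_scales_dtype, enable_pdl,
    compile_kernel=lambda *key: object(),
):
    key = (ab_dtype, sf_dtype, out_dtype, token_scales_dtype, enable_pdl)
    if key not in _kernel_cache:
        _kernel_cache[key] = compile_kernel(*key)
    return _kernel_cache[key]
```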
tests/moe/test_cute_dsl_fused_moe.py (1)

38-56: ⚠️ Potential issue | 🟠 Major

Use flashinfer.utils architecture helpers for skip gating.

The custom is_sm100_family() check should use flashinfer.utils.get_compute_capability() instead of direct torch.cuda calls. Additionally, the skip condition mentions SM110 in the reason but only checks for major == 10; SM110 is major == 11. Either correct the reason to exclude SM110, or extend the check to include both SM100 (major 10) and SM110 (major 11) families:

from flashinfer.utils import get_compute_capability

def is_sm100_family():
    """Check for SM100/SM110 family (Blackwell)."""
    if not torch.cuda.is_available():
        return False
    major, _ = get_compute_capability(torch.device("cuda"))
    return major in [10, 11]  # SM100/SM103 (10) or SM110 (11)

Or use the composed helpers: is_sm100a_supported() or is_sm110a_supported().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 38 - 56, Replace the
custom torch.cuda-based is_sm100_family() with a compute-capability check using
flashinfer.utils (e.g., call get_compute_capability(torch.device("cuda")) to
obtain major) or use the composed helpers is_sm100a_supported() or
is_sm110a_supported(); update the function docstring to say SM100/SM110
(Blackwell) and make the return condition accept major in [10, 11], and update
the sm100_required pytest reason string to match (or restrict it to only SM100
if you intentionally keep major == 10).
🧹 Nitpick comments (2)
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py (1)

395-443: Document the new enable_pdl knob.

enable_pdl is now part of the public constructor, but the constructor docstring still omits it. A short :param enable_pdl: entry would make the launch behavior easier to discover for callers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 395 - 443, The constructor __init__ now accepts the enable_pdl
parameter but the docstring is missing documentation for it; update the __init__
docstring to add a short ":param enable_pdl: bool" description (and optionally a
brief note in the top descriptive section) explaining what enabling/disabling
PDL does and its default (True) so callers can discover this knob; reference the
enable_pdl attribute set on self and keep wording consistent with existing param
entries like vectorized_f32 and topk.
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

357-387: Document enable_pdl in the constructor docs.

This is now a public configuration parameter, but the constructor docstring still stops before describing it. Adding a :param enable_pdl: entry would make the API surface clearer.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 357 - 387, The constructor docstring for __init__ is missing
documentation for the new public parameter enable_pdl; update the docstring to
add a ":param enable_pdl: Boolean, True to enable PDL
(Programmatic Dependent Launch) in the kernel." and a corresponding
":type enable_pdl: bool" line, placing it alongside the other param entries
(near mma_tiler_mn, cluster_shape_mn, use_blkred, raster_along_m) so callers of
the class and generated docs clearly see the new configuration option referenced
by the enable_pdl attribute.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_moe_deepseek.py`:
- Around line 71-74: The runtime guard message in main() is out of sync with the
updated docstring: change the error/exit message emitted by main() (the SM
check) to reference SM100/SM103 (Blackwell) instead of SM110, so the runtime
error text matches the docstring and tells users that CuteDSL MoE NVFP4 kernels
are optimized for SM100/SM103 only; update the string used in main()
accordingly.

In `@benchmarks/routines/moe.py`:
- Around line 1176-1179: The test data creator _create_cute_dsl_moe_test_data
currently always stages activations/weights as bfloat16 (variables like x_bf16
and w_bf16) then quantizes, causing reported args.input_dtype/args.weight_dtype
to be incorrect; change the tensor constructions to use args.input_dtype for
input activations and args.weight_dtype for weights (and any similarly staged
tensors) so the initial dtype matches the CLI metadata and the subsequent
quantization path remains unchanged; update all occurrences where tensors are
hardcoded to torch.bfloat16 before quantization so metadata/bandwidth reporting
reflects the actual requested dtypes.

In `@flashinfer/fused_moe/cute_dsl/tuner.py`:
- Around line 418-429: The current code appends DEFAULT_MOE_TACTIC when
valid_tactics is empty even though every candidate was rejected by
can_implement(), which hides the real incompatibility; instead, remove the
fallback to DEFAULT_MOE_TACTIC and surface a clear error: when valid_tactics is
empty after running can_implement() checks (the block that currently references
valid_tactics and DEFAULT_MOE_TACTIC), log a detailed warning including the
problem dims (num_tokens, hidden_size, intermediate_size, num_local_experts,
self.top_k) and raise an exception (e.g., ValueError or RuntimeError) explaining
that no tactics can implement the problem and include suggestions or the
rejected candidate list if available so callers get an actionable failure
instead of deferring to kernel launch.

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 995-1010: Add a public wrapper method on CuteDslMoEWrapper named
get_valid_tactics_for_inputs(self, inputs: List[torch.Tensor]) that delegates to
the internal runner by calling self._runner.get_valid_tactics(inputs, None) and
returns the filtered tactics; update tests to call
moe.get_valid_tactics_for_inputs(inputs) instead of touching the private
_runner. Ensure the method signature accepts a list of torch.Tensor and simply
forwards the call/result without changing semantics.
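The fail-fast behavior suggested for the tuner above can be sketched as follows; select_tactics and its arguments are hypothetical names, not the real get_valid_tactics() signature:

```python
# Hypothetical sketch: instead of silently falling back to a default tactic
# when every candidate fails can_implement(), raise with the problem dims
# so callers get an actionable error before any kernel launch.

def select_tactics(candidates, can_implement, problem_dims):
    valid = [t for t in candidates if can_implement(t, problem_dims)]
    if not valid:
        raise ValueError(
            "No MoE tactic can implement problem "
            f"{problem_dims}; rejected candidates: {candidates}"
        )
    return valid
```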

---

Outside diff comments:
In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 162-191: The cache key for _get_compiled_finalize_kernel must
include the element dtypes of the typed pointers and token scales so compiled
wrappers aren’t reused across incompatible pointer/scale types; update the key
construction to incorporate the element types for
a_ptr/b_ptr/c_ptr/a_sf_ptr/b_sf_ptr/alpha_ptr and the dtype of token_scales_ptr
(and keep the existing enable_pdl flag), and apply the same change to the
companion key-building site that creates the finalize-kernel key for the other
code path so both PDL and non-PDL kernels are keyed by pointer element types and
token_scales dtype.

In `@flashinfer/fused_moe/cute_dsl/tuner.py`:
- Around line 299-308: get_valid_tactics() currently validates GEMM2 using a
hard-coded cutlass.BFloat16 type which mis-validates tactics when
CuteDslFusedMoENvfp4Runner was constructed with a different output_dtype; update
get_valid_tactics() to map/convert self.output_dtype to the correct cutlass
dtype (or raise a clear error if only BF16 is supported) before querying the
finalize kernel and GEMM2 validation so the alignment/packing rules match the
runner’s configured output_dtype (references: CuteDslFusedMoENvfp4Runner,
get_valid_tactics, self.output_dtype, GEMM2, finalize kernel).
- Around line 299-308: The runner's __hash__ currently omits enable_pdl so
autotune/JIT cache keys mix PDL and non-PDL runs; update the class's __hash__
implementation to include self.enable_pdl alongside the other fields used
(forward_impl, num_experts, top_k, num_local_experts, local_expert_offset,
use_fused_finalize, output_dtype) so the hash reflects the PDL flag; do the same
change for the other similar class/section referenced (the block around where
enable_pdl is set at lines ~310-320) so both runner hashes incorporate
enable_pdl.

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 38-56: Replace the custom torch.cuda-based is_sm100_family() with
a compute-capability check using flashinfer.utils (e.g., call
get_compute_capability(torch.device("cuda")) to obtain major) or use the
composed helpers is_sm100a_supported() or is_sm110a_supported(); update the
function docstring to say SM100/SM110 (Blackwell) and make the return condition
accept major in [10, 11], and update the sm100_required pytest reason string to
match (or restrict it to only SM100 if you intentionally keep major == 10).

---

Nitpick comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 395-443: The constructor __init__ now accepts the enable_pdl
parameter but the docstring is missing documentation for it; update the __init__
docstring to add a short ":param enable_pdl: bool" description (and optionally a
brief note in the top descriptive section) explaining what enabling/disabling
PDL does and its default (True) so callers can discover this knob; reference the
enable_pdl attribute set on self and keep wording consistent with existing param
entries like vectorized_f32 and topk.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 357-387: The constructor docstring for __init__ is missing
documentation for the new public parameter enable_pdl; update the docstring to
add a ":param enable_pdl: Boolean, True to enable PDL
(Programmatic Dependent Launch) in the kernel." and a corresponding
":type enable_pdl: bool" line, placing it alongside the other param entries
(near mma_tiler_mn, cluster_shape_mn, use_blkred, raster_along_m) so callers of
the class and generated docs clearly see the new configuration option referenced
by the enable_pdl attribute.
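Both nitpicks ask for the same docstring addition; a minimal version, using a hypothetical class as a stand-in for the kernels named above, might read:

```python
class GroupedGemmFinalizeFusionKernel:
    """Toy stand-in for the kernel classes named above; only the docstring
    shape of __init__ is the point here."""

    def __init__(self, mma_tiler_mn=(128, 128), enable_pdl=True):
        """
        :param mma_tiler_mn: MMA tile shape (M, N).
        :param enable_pdl: Enable Programmatic Dependent Launch (PDL) for
            the kernel launch; defaults to True.
        :type enable_pdl: bool
        """
        self.mma_tiler_mn = mma_tiler_mn
        self.enable_pdl = enable_pdl
```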

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4bb78545-d18c-43a4-81f0-b9474ed91c80

📥 Commits

Reviewing files that changed from the base of the PR and between e3aa638 and f450151.

📒 Files selected for processing (18)
  • benchmarks/bench_moe_deepseek.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/moe.py
  • benchmarks/samples/sample_testlist.txt
  • flashinfer/fused_moe/cute_dsl/blackwell/__init__.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/utils.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
  • flashinfer/jit/moe_utils.py
  • tests/moe/test_cute_dsl_fused_moe.py
💤 Files with no reviewable changes (5)
  • flashinfer/fused_moe/cute_dsl/blackwell/__init__.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm.py
  • flashinfer/fused_moe/cute_dsl/blackwell/utils.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm.py

Comment on lines +1176 to +1179
# Input activations
x_bf16 = (
torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device=device) / 10
)
⚠️ Potential issue | 🟡 Minor

This benchmark can report the wrong input/weight dtypes.

The new CuteDSL route records args.input_dtype and args.weight_dtype, but _create_cute_dsl_moe_test_data() always stages BF16 activations and weights before quantization. Any non-default CLI dtype will benchmark BF16 and then publish misleading metadata/bandwidth numbers.

Also applies to: 1192-1199, 1223-1230, 1292-1293, 1317-1326

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` around lines 1176 - 1179, The test data creator
_create_cute_dsl_moe_test_data currently always stages activations/weights as
bfloat16 (variables like x_bf16 and w_bf16) then quantizes, causing reported
args.input_dtype/args.weight_dtype to be incorrect; change the tensor
constructions to use args.input_dtype for input activations and
args.weight_dtype for weights (and any similarly staged tensors) so the initial
dtype matches the CLI metadata and the subsequent quantization path remains
unchanged; update all occurrences where tensors are hardcoded to torch.bfloat16
before quantization so metadata/bandwidth reporting reflects the actual
requested dtypes.
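The fix amounts to threading the CLI dtypes into the staging helper instead of hardcoding torch.bfloat16. A minimal stand-in sketch of that plumbing (SimpleNamespace plays the argparse namespace, and plain dicts stand in for torch tensor allocation; none of these names are the benchmark's real helpers):

```python
from types import SimpleNamespace

# Stand-in for torch dtype objects keyed by the CLI string.
DTYPE_MAP = {"bfloat16": "torch.bfloat16", "float16": "torch.float16"}


def create_moe_test_data(args, num_tokens, hidden_size):
    # Stage tensors in the requested dtypes rather than a hardcoded bfloat16,
    # so reported metadata matches what is actually benchmarked.
    x = {"shape": (num_tokens, hidden_size), "dtype": DTYPE_MAP[args.input_dtype]}
    w = {"shape": (hidden_size, hidden_size), "dtype": DTYPE_MAP[args.weight_dtype]}
    return x, w


args = SimpleNamespace(input_dtype="float16", weight_dtype="bfloat16")
x, w = create_moe_test_data(args, 8, 64)
assert x["dtype"] == "torch.float16" and w["dtype"] == "torch.bfloat16"
```

The quantization step that follows in the real helper stays unchanged; only the initial staging dtype moves from a literal to the namespace attributes.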

Comment on lines +418 to +429
if not valid_tactics:
logger.warning(
"No valid tactics found for problem dims "
"(tokens=%d, hidden=%d, intermediate=%d, experts=%d, top_k=%d). "
"Falling back to default tactic.",
num_tokens,
hidden_size,
intermediate_size,
num_local_experts,
self.top_k,
)
valid_tactics = [DEFAULT_MOE_TACTIC]

⚠️ Potential issue | 🟠 Major

Don't fall back to a tactic that can_implement() already rejected.

If valid_tactics is empty, every candidate failed the actual shape/alignment checks. Returning DEFAULT_MOE_TACTIC here just defers the failure to kernel launch with a much less actionable error.

Suggested fix
-        if not valid_tactics:
-            logger.warning(
-                "No valid tactics found for problem dims "
-                "(tokens=%d, hidden=%d, intermediate=%d, experts=%d, top_k=%d). "
-                "Falling back to default tactic.",
-                num_tokens,
-                hidden_size,
-                intermediate_size,
-                num_local_experts,
-                self.top_k,
-            )
-            valid_tactics = [DEFAULT_MOE_TACTIC]
+        if not valid_tactics:
+            raise ValueError(
+                "No valid CuteDSL MoE tactics for the current problem dimensions."
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/tuner.py` around lines 418 - 429, The current
code appends DEFAULT_MOE_TACTIC when valid_tactics is empty even though every
candidate was rejected by can_implement(), which hides the real incompatibility;
instead, remove the fallback to DEFAULT_MOE_TACTIC and surface a clear error:
when valid_tactics is empty after running can_implement() checks (the block that
currently references valid_tactics and DEFAULT_MOE_TACTIC), log a detailed
warning including the problem dims (num_tokens, hidden_size, intermediate_size,
num_local_experts, self.top_k) and raise an exception (e.g., ValueError or
RuntimeError) explaining that no tactics can implement the problem and include
suggestions or the rejected candidate list if available so callers get an
actionable failure instead of deferring to kernel launch.
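The prompt also suggests surfacing the rejected candidate list. A toy sketch of that error path (can_implement and the tactic tuples here are stand-ins, not flashinfer's real runner API):

```python
def can_implement(tactic, hidden_size):
    # Stand-in shape check: real code validates alignment, SMEM budget, etc.
    tile_m, tile_n = tactic
    if hidden_size % tile_n != 0:
        return False, f"hidden_size {hidden_size} not divisible by tile_n {tile_n}"
    return True, ""


def select_tactics(candidates, hidden_size):
    valid, rejected = [], []
    for tactic in candidates:
        ok, reason = can_implement(tactic, hidden_size)
        if ok:
            valid.append(tactic)
        else:
            rejected.append((tactic, reason))
    if not valid:
        # Fail here with detail instead of deferring to kernel launch.
        raise ValueError(
            "No valid CuteDSL MoE tactics for the current problem dimensions; "
            f"rejected candidates: {rejected}"
        )
    return valid


assert select_tactics([(128, 128), (128, 256)], 512) == [(128, 128), (128, 256)]
```

Raising with the rejection reasons gives callers an actionable message at tuning time rather than an opaque launch failure later.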

Comment on lines +995 to +1010
# Get the filtered list of valid tactics for this problem size
inputs = [
tensors["x"],
tensors["x_sf"],
tensors["token_selected_experts"],
tensors["token_final_scales"],
tensors["w1_weight"],
tensors["w1_weight_sf"],
tensors["w1_alpha"],
tensors["fc2_input_scale"],
tensors["w2_weight"],
tensors["w2_weight_sf"],
tensors["w2_alpha"],
]
valid_tactics = moe._runner.get_valid_tactics(inputs, None)
assert len(valid_tactics) > 0, "No valid tactics found"

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether a public tactic-query API exists and where private access is used.
rg -nP --type=py 'class\s+CuteDslMoEWrapper\b|def\s+get_valid_tactics\s*\('
rg -nP --type=py '\._runner\.get_valid_tactics\s*\('



🏁 Script executed:

#!/bin/bash
# Examine the CuteDslMoEWrapper class and get_valid_tactics method
sed -n '301,650p' flashinfer/fused_moe/cute_dsl/fused_moe.py | cat -n



🏁 Script executed:

#!/bin/bash
# Also check what _runner is and its type
rg -A 5 'self\._runner\s*=' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -30



🏁 Script executed:

#!/bin/bash
# Check the test file to see how moe object is instantiated
sed -n '990,1020p' tests/moe/test_cute_dsl_fused_moe.py | cat -n



🏁 Script executed:

#!/bin/bash
# Find CuteDslFusedMoENvfp4Runner and its get_valid_tactics method
rg -B 3 -A 10 'class CuteDslFusedMoENvfp4Runner' flashinfer/fused_moe/cute_dsl/tuner.py



🏁 Script executed:

#!/bin/bash
# Check the get_valid_tactics implementation in tuner.py
sed -n '320,350p' flashinfer/fused_moe/cute_dsl/tuner.py | cat -n



🏁 Script executed:

#!/bin/bash
# Check if there are other accesses to _runner.get_valid_tactics in the codebase
rg -n '\._runner\.get_valid_tactics' --type=py



🏁 Script executed:

#!/bin/bash
# Check what ALL_MOE_TACTICS contains
rg -B 2 -A 5 'ALL_MOE_TACTICS\s*=' flashinfer/fused_moe/cute_dsl/



🏁 Script executed:

#!/bin/bash
# Check the full context around line 612 to see if get_valid_tactics should be updated
sed -n '300,320p' flashinfer/fused_moe/cute_dsl/fused_moe.py | cat -n



Add a public method to CuteDslMoEWrapper for getting filtered tactics instead of accessing private _runner.

The test accesses moe._runner.get_valid_tactics(inputs, None) to retrieve tactics filtered by the actual problem dimensions. This couples the test to private internals. Create a public method like get_valid_tactics_for_inputs(self, inputs: List[torch.Tensor]) in CuteDslMoEWrapper that delegates to self._runner.get_valid_tactics(), allowing tests to retrieve properly filtered tactics through the public API.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 995 - 1010, Add a public
wrapper method on CuteDslMoEWrapper named get_valid_tactics_for_inputs(self,
inputs: List[torch.Tensor]) that delegates to the internal runner by calling
self._runner.get_valid_tactics(inputs, None) and returns the filtered tactics;
update tests to call moe.get_valid_tactics_for_inputs(inputs) instead of
touching the private _runner. Ensure the method signature accepts a list of
torch.Tensor and simply forwards the call/result without changing semantics.
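The delegation the bot asks for is a thin public forwarder. A toy sketch of the shape (the runner here is a stub; the real CuteDslMoEWrapper constructs a CuteDslFusedMoENvfp4Runner with actual tactic filtering):

```python
from typing import Any, List


class _StubRunner:
    """Stand-in for the private runner; real code filters via can_implement()."""

    def get_valid_tactics(self, inputs: List[Any], profile: Any) -> List[tuple]:
        return [(128, 128), (128, 256)]


class CuteDslMoEWrapper:
    def __init__(self) -> None:
        self._runner = _StubRunner()

    def get_valid_tactics_for_inputs(self, inputs: List[Any]) -> List[tuple]:
        """Public accessor so tests need not reach into the private _runner."""
        return self._runner.get_valid_tactics(inputs, None)


moe = CuteDslMoEWrapper()
tactics = moe.get_valid_tactics_for_inputs([])
assert len(tactics) > 0
```

With this in place the test becomes `moe.get_valid_tactics_for_inputs(inputs)` and no longer breaks if the runner attribute is renamed.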

@nv-yunzheq
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !411 has been updated with latest changes, and the CI pipeline #46013906 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46013906: 7/20 passed

@yongwww
Member

yongwww commented Mar 13, 2026

I cancelled the PR test because the CI won't pass before #2781 lands. Please re-trigger the test after that PR is merged.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
benchmarks/routines/moe.py (1)

1155-1155: Consider adding an error message to the assertion.

A bare assert without a message makes debugging harder when the condition fails.

💡 Suggested improvement
-    assert sizes[dim] % (group_size * 2) == 0
+    assert sizes[dim] % (group_size * 2) == 0, (
+        f"Dimension {dim} size ({sizes[dim]}) must be divisible by {group_size * 2}"
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` at line 1155, The assertion "assert sizes[dim] %
(group_size * 2) == 0" is bare and should include a helpful message; update it
to either an assert with a formatted message (e.g. assert sizes[dim] %
(group_size * 2) == 0, f"sizes[{dim}] ({sizes[dim]}) must be divisible by
2*group_size ({group_size*2})") or raise a ValueError with the same formatted
message to make failures clear; modify the assertion in
benchmarks/routines/moe.py where sizes, dim, and group_size are used so the
runtime error explains the violated condition.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1402-1415: This routine's autotune block ignores a provided
--autotune_cache; update the autotune invocation to accept and pass the cache
path from args (check getattr(args, "autotune_cache", None) and only pass it
when present) so the autotune context manager is called with the cache option;
specifically, modify the with autotune(True): line to something like with
autotune(True, cache_path=args.autotune_cache) (or the correct parameter name
used by the autotune context manager), leaving run_cute_dsl_moe and
autotune_args unchanged.

---

Nitpick comments:
In `@benchmarks/routines/moe.py`:
- Line 1155: The assertion "assert sizes[dim] % (group_size * 2) == 0" is bare
and should include a helpful message; update it to either an assert with a
formatted message (e.g. assert sizes[dim] % (group_size * 2) == 0,
f\"sizes[{dim}] ({sizes[dim]}) must be divisible by 2*group_size
({group_size*2})\") or raise a ValueError with the same formatted message to
make failures clear; modify the assertion in benchmarks/routines/moe.py where
sizes, dim, and group_size are used so the runtime error explains the violated
condition.


📥 Commits

Reviewing files that changed from the base of the PR and between fd50e23 and 71042bf.

📒 Files selected for processing (1)
  • benchmarks/routines/moe.py

Comment on lines +1402 to +1415
if getattr(args, "autotune", False):
warmup_iters = (
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
backend = "cute-dsl_autotune"
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for CuteDSL NVFP4 MoE: {warmup_iters} iters")
autotune_args = tuple(
t.clone() if isinstance(t, torch.Tensor) else t for t in input_args
)
with autotune(True):
for _ in range(warmup_iters):
run_cute_dsl_moe(*autotune_args)
del autotune_args

⚠️ Potential issue | 🟡 Minor

Autotune cache path not supported.

Unlike other test functions (testTrtllmFp4BlockScaleMoe at line 671, testCutlassFusedMoe at line 1068), this function doesn't support the autotune_cache argument. Users who pass --autotune_cache will have their cache ignored for this routine.

🔧 Proposed fix to add cache support
+    cache_path = getattr(args, "autotune_cache", None)
+
     # Optional autotune warmup.
     # Clone input_args so autotune tactic exploration doesn't corrupt the
     # original tensors used by the subsequent benchmark.
     if getattr(args, "autotune", False):
         warmup_iters = (
             args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
         )
         backend = "cute-dsl_autotune"
         if args.verbose >= 1:
             print(f"[INFO] Autotune warmup for CuteDSL NVFP4 MoE: {warmup_iters} iters")
         autotune_args = tuple(
             t.clone() if isinstance(t, torch.Tensor) else t for t in input_args
         )
-        with autotune(True):
+        with autotune(True, cache=cache_path):
             for _ in range(warmup_iters):
                 run_cute_dsl_moe(*autotune_args)
         del autotune_args
+    elif cache_path:
+        with autotune(False, cache=cache_path):
+            pass
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` around lines 1402 - 1415, This routine's autotune
block ignores a provided --autotune_cache; update the autotune invocation to
accept and pass the cache path from args (check getattr(args, "autotune_cache",
None) and only pass it when present) so the autotune context manager is called
with the cache option; specifically, modify the with autotune(True): line to
something like with autotune(True, cache_path=args.autotune_cache) (or the
correct parameter name used by the autotune context manager), leaving
run_cute_dsl_moe and autotune_args unchanged.

@bkryu bkryu left a comment

Approved. What LGTM:

  • Benchmark changes look correct
  • Removal of unused less-fused kernel variants -- vanilla grouped gemm and grouped gemm + swiglu -- makes sense to me

@aleozlx aleozlx enabled auto-merge (squash) March 17, 2026 23:18
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 342-343: Replace the `@supported_compute_capability` decorator on
CuteDslMoEWrapper.__init__ with `@backend_requirement` so the public API exposes
both is_compute_capability_supported(cc) and is_backend_supported(); to do this,
create a separate small check function (e.g., cute_dsl_moe_compute_check)
decorated with `@supported_compute_capability`([100, 103]) that returns the
compute-capability check, then apply `@backend_requirement`({},
common_check=cute_dsl_moe_compute_check) and `@flashinfer_api` to
CuteDslMoEWrapper.__init__ (mirroring the nvfp4_kv_dequantize pattern) so
backend discovery and compute-capability checks are both available.


📥 Commits

Reviewing files that changed from the base of the PR and between 71042bf and 476899a.

📒 Files selected for processing (6)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • flashinfer/jit/moe_utils.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/jit/moe_utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Comment on lines +342 to 343
@supported_compute_capability([100, 103])
@flashinfer_api

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Decorator definitions ==="
rg -n -C5 'def supported_compute_capability\b|def backend_requirement\b' flashinfer

echo
echo "=== Backend metadata helpers exposed in flashinfer/ ==="
rg -n -C2 'is_compute_capability_supported|is_backend_supported' flashinfer -g '!**/tests/**'

echo
echo "=== Class-based API decoration patterns ==="
rg -n -C3 'class .*Wrapper|@backend_requirement|@supported_compute_capability' flashinfer -g '!**/tests/**'



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== CuteDslMoEWrapper class definition and methods ==="
sed -n '302,700p' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -150

echo
echo "=== Check if CuteDslMoEWrapper or __init__ expose helper methods ==="
rg -n -A30 'class CuteDslMoEWrapper' flashinfer/fused_moe/cute_dsl/fused_moe.py | grep -E 'is_compute_capability_supported|is_backend_supported|@property|def ' | head -20



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check if CuteDslMoEWrapper has is_compute_capability_supported or is_backend_supported as instance/class methods ==="
grep -n 'is_compute_capability_supported\|is_backend_supported' flashinfer/fused_moe/cute_dsl/fused_moe.py

echo
echo "=== Check how other Wrapper classes are decorated (if at all) ==="
rg -n 'class.*Wrapper' flashinfer -g '!**/tests/**' | head -10

echo
echo "=== Example: Check BatchDecodeWithPagedKVCacheWrapper for comparison ==="
sed -n '593,650p' flashinfer/decode.py

echo
echo "=== Check if functions (not classes) are the primary target for compute capability decorators ==="
rg -B2 '@backend_requirement' flashinfer -g '!**/tests/**' | grep -E '(def |@backend_requirement)' | head -20



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Example: Check BatchDecodeWithPagedKVCacheWrapper for comparison ==="
sed -n '593,650p' flashinfer/decode.py

echo
echo "=== Check if functions (not classes) are the primary target for compute capability decorators ==="
rg -B2 '@backend_requirement' flashinfer -g '!**/tests/**' | grep -E '(def |@backend_requirement)' | head -30

echo
echo "=== Check CuteDslMoEWrapper.run method ==="
rg -n -A10 'def run' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -30



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check if @flashinfer_api handles the compute capability check for methods ==="
rg -n -B5 -A15 'def flashinfer_api' flashinfer/utils.py | head -50

echo
echo "=== Verify if CuteDslMoEWrapper.__init__ actually has the helper methods attached ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')

# Try to import and check if __init__ has the helper
try:
    from flashinfer.fused_moe.cute_dsl.fused_moe import CuteDslMoEWrapper
    init_method = CuteDslMoEWrapper.__init__
    
    print("Checking __init__ attributes:")
    print(f"  has _supported_ccs: {hasattr(init_method, '_supported_ccs')}")
    print(f"  has is_compute_capability_supported: {hasattr(init_method, 'is_compute_capability_supported')}")
    print(f"  has is_backend_supported: {hasattr(init_method, 'is_backend_supported')}")
    
    print("\nChecking CuteDslMoEWrapper class attributes:")
    print(f"  has _supported_ccs: {hasattr(CuteDslMoEWrapper, '_supported_ccs')}")
    print(f"  has is_compute_capability_supported: {hasattr(CuteDslMoEWrapper, 'is_compute_capability_supported')}")
    print(f"  has is_backend_supported: {hasattr(CuteDslMoEWrapper, 'is_backend_supported')}")
    
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()
EOF



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Search for flashinfer_api decorator definition ==="
rg -n 'flashinfer_api' flashinfer/utils.py | head -5

echo
echo "=== Get decorator definition ==="
rg -n -A20 'def flashinfer_api' flashinfer

echo
echo "=== Verify class instance does NOT have helper methods at runtime ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')

try:
    # Check if __init__ method itself has the helpers
    from flashinfer.utils import supported_compute_capability
    import inspect
    
    @supported_compute_capability([100, 103])
    def example_func():
        pass
    
    print("Function decorated with `@supported_compute_capability`:")
    print(f"  has _supported_ccs: {hasattr(example_func, '_supported_ccs')}")
    print(f"  has is_compute_capability_supported: {hasattr(example_func, 'is_compute_capability_supported')}")
    
    # Now check a class method
    class ExampleClass:
        @supported_compute_capability([100, 103])
        def __init__(self):
            pass
    
    print("\nClass method decorated with `@supported_compute_capability`:")
    print(f"  __init__ has _supported_ccs: {hasattr(ExampleClass.__init__, '_supported_ccs')}")
    print(f"  __init__ has is_compute_capability_supported: {hasattr(ExampleClass.__init__, 'is_compute_capability_supported')}")
    
    instance = ExampleClass()
    print(f"  instance has is_compute_capability_supported: {hasattr(instance, 'is_compute_capability_supported')}")
    
except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()
EOF



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Find flashinfer_api definition ==="
fd flashinfer_api flashinfer --type f

echo
echo "=== Search more broadly for flashinfer_api ==="
rg 'def flashinfer_api' flashinfer

echo
echo "=== Check what @flashinfer_api is in the imports/exports ==="
rg -n 'flashinfer_api' flashinfer/__init__.py

echo
echo "=== Verify actual decorator stacking on CuteDslMoEWrapper.__init__ ==="
python3 -c "
import sys
sys.path.insert(0, '/repo')
from flashinfer.fused_moe.cute_dsl.fused_moe import CuteDslMoEWrapper

# Check what the __init__ actually is
init_fn = CuteDslMoEWrapper.__init__
print('Type of __init__:', type(init_fn))
print('__init__ has _supported_ccs:', hasattr(init_fn, '_supported_ccs'))
print('__init__ has is_compute_capability_supported:', hasattr(init_fn, 'is_compute_capability_supported'))
print('__init__ attributes:', [a for a in dir(init_fn) if not a.startswith('_')])
" 2>&1 || echo "(import failed)"

echo
echo "=== Compare with a decorated standalone function ==="
rg -A5 'def cute_dsl_fused_moe_nvfp4' flashinfer/fused_moe/cute_dsl/fused_moe.py

echo
echo "=== Check if standalone functions use backend_requirement or supported_compute_capability directly ==="
rg -B3 'def cute_dsl_fused_moe_nvfp4' flashinfer/fused_moe/cute_dsl/fused_moe.py



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check flashinfer_api definition in api_logging.py ==="
rg -n -A20 'def flashinfer_api' flashinfer/api_logging.py

echo
echo "=== Check how @backend_requirement and @flashinfer_api work together ==="
rg -B5 -A10 '@backend_requirement.*\n.*@flashinfer_api' flashinfer/quantization/fp4_quantization.py | head -30

echo
echo "=== Review the decorated function pattern to confirm backend_requirement is the standard ==="
sed -n '1269,1280p' flashinfer/quantization/fp4_quantization.py

echo
echo "=== Now check CuteDslMoEWrapper pattern more carefully ==="
sed -n '342,365p' flashinfer/fused_moe/cute_dsl/fused_moe.py



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check decorator pattern for standalone function ==="
sed -n '1269,1276p' flashinfer/quantization/fp4_quantization.py

echo
echo "=== Check CuteDslMoEWrapper.__init__ decoration ==="
sed -n '342,344p' flashinfer/fused_moe/cute_dsl/fused_moe.py

echo
echo "=== Verify the key difference: backend_requirement on function vs supported_compute_capability on method ==="
rg -B2 'def cute_dsl_fused_moe_nvfp4' flashinfer/fused_moe/cute_dsl/fused_moe.py | head -10

echo
echo "=== Check if there are ANY Wrapper classes with backend_requirement on __init__ ==="
rg -B3 'class.*Wrapper' flashinfer -g '!**/tests/**' -A50 | rg '@backend_requirement' | head -5

echo
echo "=== Confirm: standalone functions use @backend_requirement, not class methods ==="
rg '@backend_requirement' flashinfer -B3 -g '!**/tests/**' | grep -E '(def |@backend)' | head -15



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check flashinfer_api full implementation ==="
sed -n '1420,1500p' flashinfer/api_logging.py

echo
echo "=== Does flashinfer_api handle backend requirement checking ==="
rg -A50 'def flashinfer_api' flashinfer/api_logging.py | grep -E '(is_backend|is_compute|backend_requirement)' | head -10

echo
echo "=== Examine if @supported_compute_capability on standalone functions provides is_backend_supported ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')

from flashinfer.fused_moe.cute_dsl.fused_moe import cute_dsl_fused_moe_nvfp4

print("cute_dsl_fused_moe_nvfp4 attributes:")
print(f"  has _supported_ccs: {hasattr(cute_dsl_fused_moe_nvfp4, '_supported_ccs')}")
print(f"  has is_compute_capability_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_compute_capability_supported')}")
print(f"  has is_backend_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_backend_supported')}")

if hasattr(cute_dsl_fused_moe_nvfp4, 'is_compute_capability_supported'):
    print(f"  is_compute_capability_supported(100): {cute_dsl_fused_moe_nvfp4.is_compute_capability_supported(100)}")
EOF



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check if cute_dsl_fused_moe_nvfp4 has is_backend_supported ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')

from flashinfer.fused_moe.cute_dsl.fused_moe import cute_dsl_fused_moe_nvfp4

print("cute_dsl_fused_moe_nvfp4 attributes:")
print(f"  has _supported_ccs: {hasattr(cute_dsl_fused_moe_nvfp4, '_supported_ccs')}")
print(f"  has is_compute_capability_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_compute_capability_supported')}")
print(f"  has is_backend_supported: {hasattr(cute_dsl_fused_moe_nvfp4, 'is_backend_supported')}")
EOF

echo
echo "=== Compare with a function that uses backend_requirement ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')

from flashinfer.quantization.fp4_quantization import nvfp4_kv_dequantize

print("nvfp4_kv_dequantize (uses `@backend_requirement`) attributes:")
print(f"  has _supported_ccs: {hasattr(nvfp4_kv_dequantize, '_supported_ccs')}")
print(f"  has is_compute_capability_supported: {hasattr(nvfp4_kv_dequantize, 'is_compute_capability_supported')}")
print(f"  has is_backend_supported: {hasattr(nvfp4_kv_dequantize, 'is_backend_supported')}")
EOF

echo
echo "=== Summary: Check decorator stacking on functions in fused_moe.py ==="
rg -B1 'def cute_dsl_fused_moe_nvfp4\|def cute_dsl_fused_moe_fp4_block' flashinfer/fused_moe/cute_dsl/fused_moe.py | grep -E '(@|def )' | head -10



🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Direct code examination: Check standalone function decorators ==="
sed -n '683,690p' flashinfer/fused_moe/cute_dsl/fused_moe.py

echo
echo "=== Direct code examination: Check function with backend_requirement for comparison ==="
sed -n '1274,1280p' flashinfer/quantization/fp4_quantization.py

echo
echo "=== Key finding: Verify learning requirement one more time ==="
cat << 'EOF'
From the retrieval learning:
"Use `backend_requirement` decorator on APIs that have compute capability requirements 
and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods"

Current code patterns observed:
1. Standalone function (cute_dsl_fused_moe_nvfp4): Uses `@supported_compute_capability` directly
2. Wrapper function (nvfp4_kv_dequantize): Uses `@backend_requirement` with common_check

CuteDslMoEWrapper.__init__ currently uses:
  `@supported_compute_capability`([100, 103])
  `@flashinfer_api`

This does NOT match the learning requirement for APIs to use `@backend_requirement`
EOF

echo
echo "=== Conclusion: Review the original concern ==="
echo "The original review comment states:"
echo "  'Wrapping CuteDslMoEWrapper.__init__ with the compute-capability decorator'"
echo "  'only guards construction; it may not make CuteDslMoEWrapper itself expose'"
echo "  'the standard backend discovery helpers'"
echo ""
echo "This is CORRECT. The @supported_compute_capability decorator adds helpers to"
echo "the __init__ function object, not to the class instance. Following the learning,"
echo "this should use @backend_requirement instead."



Use @backend_requirement instead of @supported_compute_capability on CuteDslMoEWrapper.__init__.

The @supported_compute_capability decorator provides only is_compute_capability_supported on the method object. Public APIs should use @backend_requirement to expose both is_compute_capability_supported(cc) and is_backend_supported() methods, following the library's standard backend-discovery pattern.

Move the compute capability constraint into a separate check function decorated with @supported_compute_capability, then apply @backend_requirement({}, common_check=...) to __init__, mirroring the pattern used by nvfp4_kv_dequantize and other public APIs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 342-343, replace
the `@supported_compute_capability` decorator on `CuteDslMoEWrapper.__init__` with
`@backend_requirement` so the public API exposes both
`is_compute_capability_supported(cc)` and `is_backend_supported()`; to do this,
create a separate small check function (e.g., `cute_dsl_moe_compute_check`)
decorated with `@supported_compute_capability([100, 103])` that returns the
compute-capability check, then apply `@backend_requirement({},
common_check=cute_dsl_moe_compute_check)` and `@flashinfer_api` to
`CuteDslMoEWrapper.__init__` (mirroring the `nvfp4_kv_dequantize` pattern) so
backend discovery and compute-capability checks are both available.

@aleozlx aleozlx merged commit 6f0928c into flashinfer-ai:main Mar 19, 2026
31 checks passed
aleozlx added a commit that referenced this pull request Mar 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Fix API-breaking changes for the 0.6.7 release.

## 🔍 Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

**API changes review**

API changes since v0.6.6:

**PR #2520 + commit e35c19e** (fixed to be compatible)

- Function: `xqa()`
  Change: Added `k_sf_cache=None`, `v_sf_cache=None` as keyword-only params (after `*`). Backward-compatible.

**PR #2618** (has PR #2730 to fix it)

- Function: `gated_delta_rule_mtp()`
  Change: `disable_state_update: bool = True` → `Optional[bool] = None`. Still defaults to `True` at runtime but emits a deprecation warning; will flip to `False` in 0.7.0.
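The `Optional[bool]` transition above follows a common deprecation pattern. A hedged sketch of just that argument handling, with a hypothetical helper name and none of the real kernel logic:

```python
import warnings


def resolve_disable_state_update(disable_state_update=None):
    """Resolve the deprecated default described above.

    None (the new default) keeps today's behavior (True) but warns;
    an explicit True/False is honored silently.
    """
    if disable_state_update is None:
        warnings.warn(
            "disable_state_update currently defaults to True; the default "
            "will change to False in 0.7.0. Pass it explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        disable_state_update = True
    return disable_state_update
```

Callers that already pass the flag explicitly see no warning, which is what makes the signature change backward-compatible in practice.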

**PR #2775** (expected: cute DSL MoE cleanup)

- Function: `blockscaled_contiguous_grouped_gemm_nvfp4()`
  Change: Entire `@flashinfer_api`-decorated function deleted.
- Function: `blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()`
  Change: Entire `@flashinfer_api`-decorated function deleted.
- Function: `blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()`
  Change: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` param.
- Function: `blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()`
  Change: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` param.
- Function: `CuteDslMoEWrapper.__init__()`
  Change: Added `enable_pdl: bool = True` param. Backward-compatible.
- Function: `cute_dsl_fused_moe_nvfp4()`
  Change: Added `enable_pdl: bool = True` param. Backward-compatible.

**PR #2428**

- Function: `rmsnorm_quant()`
  Change: `scale: float` → `scale: Union[float, torch.Tensor]`; return type `torch.Tensor` → `None`.
- Function: `fused_add_rmsnorm_quant()`
  Change: `scale: float` → `scale: Union[float, torch.Tensor]`.

**Quantization functions (relocated, not removed)**

All quantization APIs (`fp4_quantize`, `block_scale_interleave`, `e2m1_and_ufp8sf_scale_to_float`, `shuffle_matrix_a`, `shuffle_matrix_sf_a`, `nvfp4_quantize`, `nvfp4_batched_quantize`, `scaled_fp4_grouped_quantize`, `mxfp4_quantize`, `mxfp4_dequantize`, `mxfp4_dequantize_host`, `mxfp8_quantize`, `mxfp8_dequantize_host`) were moved from `flashinfer/fp4_quantization.py` and `flashinfer/fp8_quantization.py` to `flashinfer/quantization/`. Signatures, `@flashinfer_api` decorators, and `__init__.py` exports are preserved. No breakage.

```diff
$ git diff v0.6.6 | grep -A20 "@flashinfer_api"                                               
     @flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
         sinks: Optional[torch.Tensor] = None,
         q_len_per_req: Optional[int] = 1,
         skip_softmax_threshold_scale_factor: Optional[float] = None,
+        kv_block_scales: Optional[
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        ] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.

@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper:
             enable_pdl = device_support_pdl(q.device)
         k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)

+        # Unpack kv_block_scales
+        key_block_scales = None
+        value_block_scales = None
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                key_block_scales, value_block_scales = kv_block_scales
--
-@flashinfer_api
-def fp4_quantize(
-    input: torch.Tensor,
-    global_scale: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    sf_use_ue8m0: bool = False,
-    is_sf_swizzled_layout: bool = True,
-    is_sf_8x4_layout: bool = False,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to FP4 format.
-
-    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
-@flashinfer_api
-def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
-    """Swizzle block scale tensor for FP4 format.
-
-    This function swizzles the block scale tensor to optimize memory access patterns
-    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
-
-    Args:
-        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
-
-    Returns:
-        torch.Tensor: Swizzled tensor with the same shape as input.
-
-    Raises:
-        AssertionError: If input dtype is not uint8 or bfloat16.
-    """
-    # TODO(shuw): check input dtype is uint8
-    assert (
-        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
-    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
-
--
-@flashinfer_api
-def e2m1_and_ufp8sf_scale_to_float(
-    e2m1_tensor: torch.Tensor,
-    ufp8_scale_tensor: torch.Tensor,
-    global_scale_tensor: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    ufp8_type: int = 1,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
-
-    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
-    back to float values using the associated UFP8 scale factors and global scale.
-
-    Args:
-        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
-        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
-        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
-@flashinfer_api
-def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
-    """
-    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
-    """
-    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
-
-    return input_tensor[row_indices.to(input_tensor.device)]
-
-
-@flashinfer_api
-def shuffle_matrix_sf_a(
-    input_tensor: torch.Tensor,
-    epilogue_tile_m: int,
-    num_elts_per_sf: int = 16,
-):
-    """
-    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
-    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
-    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
-    layout.
-    This function expects the input to be in linear layout. It's done this
-    way because the scaling factors in the NVFP4 checkpoints are quantized
-    and are in linear layout.
-    This function doesn't add padding.
-    """
-
-    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
-
-    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
-
--
-@flashinfer_api
-def nvfp4_quantize(
-    a,
-    a_global_sf,
-    sfLayout=SfLayout.layout_128x4,
-    do_shuffle=False,
-    sf_vec_size=16,
-    enable_pdl=None,
-):
-    """
-    Quantize input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
-        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-
--
-@flashinfer_api
-def mxfp4_quantize(a):
-    """
-    Quantize input tensor to MXFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-            - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-    """
-    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
-    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
-    return a_fp4, a_sf
-
-
-@flashinfer_api
-def mxfp4_dequantize(a_fp4, a_sf):
-    """
-    Dequantize input tensor from MXFP4 format.
-
-    Parameters:
-        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    return e2m1_and_ufp8sf_scale_to_float(
-        a_fp4.cpu().view(torch.uint8),
-        a_sf.cpu().view(torch.uint8).reshape(-1),
-        torch.tensor([1.0], device=a_fp4.device),
-        32,
-        0,
-        True,
-    )
-
--
-@flashinfer_api
-def mxfp4_dequantize_host(
-    weight: torch.Tensor,
-    scale: torch.Tensor,
-    group_size: int = 32,
-) -> torch.Tensor:
-    """
-    Dequantize input tensor from MXFP4 format on host.
-
-    Parameters:
-        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-        group_size (int, optional): Group size for dequantization. Defaults to 32.
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
-    major, minor = get_compute_capability(
-        torch.device("cuda:0")
-    )  # use any cuda device to get a compute capability
--
-@flashinfer_api
-def nvfp4_batched_quantize(
-    a,
-    a_global_sf,
-    sf_vec_size=16,
-):
-    """
-    Quantize batched input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
--
-@flashinfer_api
-def scaled_fp4_grouped_quantize(
-    a,
-    mask,
-    a_global_sf,
-):
-    """
-    quantize batched input tensor to NVFP4 format with mask.
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
-    a_fp4, a_sf = get_fp4_quantization_module(
-        device_arch
--
-@flashinfer_api
-def mxfp8_quantize(
-    input: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-    alignment: int = 32,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to MxFP8 format.
-
-    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-        alignment (int, optional): sfVecSize. Defaults to 32.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3
--
-@flashinfer_api
-def mxfp8_dequantize_host(
-    input: torch.Tensor,
-    scale_tensor: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Dequantize input tensor from MxFP8 format.
-
-    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
-    back to float values using the associated scale factors.
-
-    Args:
-        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
-        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
-
-    Returns:
-        torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
-
-    """
-
--
-@flashinfer_api
 def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     vectorized_f32: bool = True,
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads.

@@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     major, minor = get_compute_capability(a.device)
     if major != 10:
         raise ValueError(
-            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). "
+            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). "
             f"Got SM{major}{minor}."
         )

--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (128, 128),
-    cluster_shape_mn: Tuple[int, int] = (1, 1),
-    sm_count: Optional[int] = None,
-) -> torch.Tensor:
-    """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization.
-
--
-@flashinfer_api
 def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     cluster_shape_mn: Tuple[int, int] = (2, 1),
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads.

@@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
             expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
         token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
         out: Optional output tensor, shape (seq_len, n). Created if None.
-             This tensor is used for atomic accumulation, so it should be zero-initialized.
+             This tensor is used for atomic accumulation. If `out` is
+             provided, it must already be zero-initialized by the caller.
+             If `out` is None, this function allocates a zero-initialized
+             output tensor. Passing a non-zeroed `out` buffer will silently
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    out_scale: Optional[torch.Tensor] = None,
-    global_scale: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (256, 128),
-    cluster_shape_mn: Tuple[int, int] = (2, 1),
-    vectorized_f32: bool = True,
-    sm_count: Optional[int] = None,
--
     @flashinfer_api
     def __init__(
         self,
@@ -347,6 +355,7 @@ class CuteDslMoEWrapper:
         sf_vec_size: int = 16,
         output_dtype: torch.dtype = torch.bfloat16,
         device: str = "cuda",
+        enable_pdl: bool = True,
     ):
         """Initialize the MoE wrapper.

@@ -363,6 +372,7 @@ class CuteDslMoEWrapper:
             sf_vec_size: Scale factor vector size. Default: 16.
             output_dtype: Output data type. Default: torch.bfloat16.
             device: Device for buffer allocation. Default: "cuda".
+            enable_pdl: Enable Programmatic Dependent Launch. Default: True.
         """
         self.num_experts = num_experts
         self.top_k = top_k
@@ -376,6 +386,7 @@ class CuteDslMoEWrapper:
         self.sf_vec_size = sf_vec_size
--
     @flashinfer_api
@@ -550,9 +570,10 @@ class CuteDslMoEWrapper:
                 f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
             )

-        # Allocate output buffer if not using pre-allocated one
+        # Slice the pre-allocated buffer to the active batch so that
+        # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
         if self.use_cuda_graph:
-            moe_output = self._moe_output
+            moe_output = self._moe_output[:num_tokens]
         else:
             moe_output = torch.empty(
                 (num_tokens, self.hidden_size),
@@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Internal implementation called by auto-tuner for functional API."""
--
 @flashinfer_api
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
@@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Run fused MoE computation using CuteDSL NVFP4 kernels.

+    Supported architectures: SM100, SM103.
+
     This is the simple functional API. For CUDA graph support, use
     `CuteDslMoEWrapper` instead.

@@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4(
         local_expert_offset=local_expert_offset,
         use_fused_finalize=use_fused_finalize,
         output_dtype=output_dtype,
+        enable_pdl=enable_pdl,
--
 @flashinfer_api
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
@@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose(
         - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
           and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
           (supports both the direct ``state`` path and the pool+indices path).
-        - pool+indices (``initial_state``/``initial_state_indices``) only supported
-          via the bf16 fast path; float32 state raises an error.
+        - pool+indices (``initial_state``/``initial_state_indices``) supported on
+          both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path
+          (T=1). The float32 path also supports negative indices for padding.
         - Legacy path (float32 state, T=1): K and V must be multiples of 4.
     """
     # Validate input shapes
@@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose(
         return_state = initial_state if use_pool else state
         return output, return_state

-    # Legacy path: T=1 only, float32 state (no pool+indices support)
-    assert not use_pool, (
--
 @flashinfer_api
 def gated_delta_rule_mtp(
     q: torch.Tensor,
@@ -2427,7 +489,7 @@ def gated_delta_rule_mtp(
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     intermediate_states_buffer: Optional[torch.Tensor] = None,
-    disable_state_update: bool = True,
+    disable_state_update: Optional[bool] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -2463,8 +525,15 @@ def gated_delta_rule_mtp(
         intermediate_states_buffer (Optional[torch.Tensor]):
             Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``.
             If None, intermediate states are not cached.
-        disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``True``.
+        disable_state_update (Optional[bool]):
+            If True, the initial state is not updated. Currently defaults to ``True``.
+            Please pass this argument explicitly — the default will change to ``False``
--
 @flashinfer_api
@@ -60,16 +120,14 @@ def rmsnorm(
     output: torch.Tensor
         Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
     """
-    if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
     if out is None:
         out = torch.empty_like(input)
-    _rmsnorm(out, input, weight, eps, enable_pdl)
+    _rmsnorm_impl(out, input, weight, eps, enable_pdl)
     return out


 @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
-def _rmsnorm(
+def _rmsnorm_impl(
     out: torch.Tensor,
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -78,11 +136,21 @@ def _rmsnorm(
--
 @flashinfer_api
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
@@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek(
         If return_lse is False, the output will be a single tensor.
     """
     if not is_sm12x_supported(query.device):
-        major, minor = get_compute_capability(query.device)
-        if major == 12:
-            min_cuda = "13.0" if minor >= 1 else "12.8"
-            raise ValueError(
-                f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} "
-                f"for SM12{minor}x GPUs."
-            )
         raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.")
     assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
         "currently only support deepseek r1 192 query and 128 value"
     )
-    module = get_trtllm_fmha_v2_module()
+    module = get_trtllm_fmha_v2_sm120_module()
     is_e4m3 = query.dtype == torch.float8_e4m3fn
--
+@flashinfer_api
+def trtllm_fmha_v2_prefill(
+    qkv: Union[
+        torch.Tensor,
+        Tuple[torch.Tensor, torch.Tensor],
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ],
+    input_layout: str,
+    workspace_buffer: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_q_len: int,
+    max_kv_len: int,
+    bmm1_scale: float,
+    bmm2_scale: float,
+    batch_size: int,
+    cum_seq_lens_q: torch.Tensor,
+    cum_seq_lens_kv: torch.Tensor,
+    block_tables: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: Optional[Union[torch.dtype, str]] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
--
+@flashinfer_api
+def fp4_quantize(
+    input: torch.Tensor,
+    global_scale: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    sf_use_ue8m0: bool = False,
+    is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to FP4 format.
+
+    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
+@flashinfer_api
+def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
+    """Swizzle block scale tensor for FP4 format.
+
+    This function swizzles the block scale tensor to optimize memory access patterns
+    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
+
+    Args:
+        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
+
+    Returns:
+        torch.Tensor: Swizzled tensor with the same shape as input.
+
+    Raises:
+        AssertionError: If input dtype is not uint8 or bfloat16.
+    """
+    # TODO(shuw): check input dtype is uint8
+    assert (
+        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
+    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
+
--
+@flashinfer_api
+def e2m1_and_ufp8sf_scale_to_float(
+    e2m1_tensor: torch.Tensor,
+    ufp8_scale_tensor: torch.Tensor,
+    global_scale_tensor: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    ufp8_type: int = 1,
+    is_sf_swizzled_layout: bool = True,
+) -> torch.Tensor:
+    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
+
+    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
+    back to float values using the associated UFP8 scale factors and global scale.
+
+    Args:
+        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
+        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
+        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
+        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
+@flashinfer_api
+def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
+    """
+    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
+    """
+    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
+
+    return input_tensor[row_indices.to(input_tensor.device)]
+
+
+@flashinfer_api
+def shuffle_matrix_sf_a(
+    input_tensor: torch.Tensor,
+    epilogue_tile_m: int,
+    num_elts_per_sf: int = 16,
+):
+    """
+    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
+    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
+    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
+    layout.
+    This function expects the input to be in linear layout. It's done this
+    way because the scaling factors in the NVFP4 checkpoints are quantized
+    and are in linear layout.
+    This function doesn't add padding.
+    """
+
+    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
+
+    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
+
--
+@flashinfer_api
+def nvfp4_quantize(
+    a,
+    a_global_sf,
+    sfLayout=SfLayout.layout_128x4,
+    do_shuffle=False,
+    sf_vec_size=16,
+    enable_pdl=None,
+):
+    """
+    Quantize input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
+        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability. Defaults to None.
+
--
+@flashinfer_api
+def mxfp4_quantize(
+    a: torch.Tensor,
+    backend: str = "cuda",
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        backend (str, optional): Backend to use for quantization.
+            - "cuda": Use CUDA kernel (default, stable)
+            - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**)
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic
+            Dependent Launch). Only used when backend="cute-dsl".
+            If None, automatically detects based on device capability.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
--
+@flashinfer_api
+def mxfp4_dequantize(a_fp4, a_sf):
+    """
+    Dequantize input tensor from MXFP4 format.
+
+    Parameters:
+        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    return e2m1_and_ufp8sf_scale_to_float(
+        a_fp4.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8).reshape(-1),
+        torch.tensor([1.0], device=a_fp4.device),
+        32,
+        0,
+        True,
+    )
+
--
+@flashinfer_api
+def mxfp4_dequantize_host(
+    weight: torch.Tensor,
+    scale: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Dequantize input tensor from MXFP4 format on host.
+
+    Parameters:
+        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+        group_size (int, optional): Group size for dequantization. Defaults to 32.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    # NOTE(Zihao): the cpu op should be decoupled from cuda ops because it is device independent; refactor this in the future
+    major, minor = get_compute_capability(
+        torch.device("cuda:0")
+    )  # use any cuda device to get a compute capability
--
+@flashinfer_api
+def nvfp4_batched_quantize(
+    a,
+    a_global_sf,
+    sf_vec_size=16,
+):
+    """
+    Quantize batched input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
--
+@flashinfer_api
+def nvfp4_quantize_paged_kv_cache(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_layout: str = "HND",
+    k_global_sf: Optional[torch.Tensor] = None,
+    v_global_sf: Optional[torch.Tensor] = None,
+) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    float,
+    float,
+]:
+    """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA.
+
+    Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling
+    (global FP32 + per-block FP8), and swizzles scale factors
+    for the SM100 trtllm-gen MHA kernel layout.
+
+    Args:
+        k_cache: Key cache tensor.
--
+@flashinfer_api
+def scaled_fp4_grouped_quantize(
+    a,
+    mask,
+    a_global_sf,
+):
+    """
+    Quantize batched input tensor to NVFP4 format with mask.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        mask (torch.Tensor): Mask tensor to apply before quantization.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
+    a_fp4, a_sf = get_fp4_quantization_module(
+        device_arch
--
+@flashinfer_api
+def nvfp4_kv_dequantize(
+    fp4_data: torch.Tensor,
+    block_scales: torch.Tensor,
+    global_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """GPU dequantization of NVFP4 KV cache data with linear block scale layout.
+
+    Requires SM80+.
+
+    Args:
+        fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+        block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]``
+            with dtype uint8.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as fp4_data.
+        output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype.
--
+@flashinfer_api
+def nvfp4_kv_quantize(
+    input: torch.Tensor,
+    global_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """GPU quantization to NVFP4 KV cache format with linear block scale layout.
+
+    Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16.
+            K must be divisible by 16.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as input.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]:
+            - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+            - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8.
+    """
+    M, K = input.shape
--
+@flashinfer_api
+def mxfp8_quantize(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: Optional[bool] = None,
+    backend: Literal["cuda", "cute-dsl"] = "cuda",
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to MxFP8 format.
+
+    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        alignment (int, optional): Scale factor vector size (sfVecSize). Defaults to 32.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0). Defaults to None.
+        backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are:
--
+@flashinfer_api
+def mxfp8_dequantize_host(
+    input: torch.Tensor,
+    scale_tensor: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> torch.Tensor:
+    """Dequantize input tensor from MxFP8 format.
+
+    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
+    back to float values using the associated scale factors.
+
+    Args:
+        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
+        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors.
+            If provided, it overrides is_sf_swizzled_layout. Defaults to None.
+            Available options are SfLayout.layout_128x4 and SfLayout.layout_linear.
+
+    Returns:
--
+@flashinfer_api
+def mxfp4_quantize_cute_dsl(
+    input: torch.Tensor,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format using CuTe-DSL kernel.
+
+    This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
+    - Global scale computed as (448 * 6) / max(|input|)
+    - UE8M0 scale factors
+    - E2M1 output format (4-bit, 2 values per byte)
+    - Swizzled (128x4) scale factor layout
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0).
--
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+    This is a GPU implementation with dual-path optimization:
+    - LINEAR layout: SF-block based iteration (fast)
+    - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+        alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)
```
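The docstrings above repeatedly describe the E2M1 (FLOAT4_E2M1X2) format: 4-bit codes, two per byte, giving outputs of shape [M, K/2] with per-block scales of shape [M, K/16] for sf_vec_size=16. For readers unfamiliar with the format, here is a minimal pure-Python sketch of one-block quantization and packing. This is illustrative only, not FlashInfer code: the helper names, the round-to-nearest choice, and the low-nibble-first packing order are assumptions.

```python
# Illustrative sketch (NOT the FlashInfer kernel): quantize one 16-value
# block to E2M1 codes and pack two codes per byte, matching the
# [M, K] -> [M, K/2] uint8 shape described in the docstrings above.
# E2M1 encoding: bit 3 is the sign, bits 2..0 index this magnitude grid.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_block_e2m1(block, block_scale):
    """Map each value to the nearest representable E2M1 magnitude."""
    codes = []
    for x in block:
        scaled = x / block_scale if block_scale != 0 else 0.0
        mag = min(abs(scaled), 6.0)  # satfinite-style clamp to E2M1 max
        idx = min(range(8), key=lambda i: abs(E2M1_GRID[i] - mag))
        codes.append((0x8 if scaled < 0 else 0x0) | idx)
    return codes

def pack_e2m1x2(codes):
    """Pack two 4-bit codes per byte (FLOAT4_E2M1X2), low nibble first."""
    return [(codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
            for i in range(0, len(codes), 2)]

block = [0.4, -1.1, 2.9, 6.5] + [0.0] * 12  # one 16-value block
packed = pack_e2m1x2(quantize_block_e2m1(block, block_scale=1.0))
assert len(packed) == 8  # K/2 bytes for K = 16 inputs
```

With sf_vec_size=16, each block of 16 inputs yields 8 packed bytes plus one scale byte, which is where the documented [M, K/2] and [M, K/16] output shapes come from.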


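The mxfp4_quantize_cute_dsl docstring states the global scale is computed as (448 * 6) / max(|input|) with UE8M0 (power-of-two) block scale factors. A hedged sketch of that arithmetic, assuming the standard E8M0 bias of 127 and a ceil rounding so the scaled block fits E2M1's maximum of 6; the kernel's exact rounding may differ, and the helper names are illustrative:

```python
import math

# Illustrative sketch of the scale-factor math described in the docstrings
# above; function names are hypothetical, not FlashInfer API.
def global_scale(values):
    """Global scale per the docstring: (448 * 6) / max(|input|)."""
    amax = max(abs(v) for v in values)
    return (448.0 * 6.0) / amax

def ue8m0_encode(block_amax):
    """UE8M0 stores only an 8-bit biased exponent: scale = 2**(byte - 127)."""
    # Choose the power of two so block_amax / scale fits E2M1's max of 6.
    exp = math.ceil(math.log2(block_amax / 6.0)) if block_amax > 0 else -127
    return max(0, min(255, exp + 127))

byte = ue8m0_encode(12.0)  # 12 / 6 = 2 -> exponent 1 -> biased byte 128
assert byte == 128 and 2.0 ** (byte - 127) == 2.0
```

The two-level scheme the docstrings describe (global FP32 scale plus per-block scale bytes) keeps the 4-bit payload in range while preserving dynamic range across blocks.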