
refactor: refactoring cuda code to cute-dsl (part 1) #2428

Merged
bkryu merged 20 commits into flashinfer-ai:main from yzh119:cute-dsl-part-1
Mar 12, 2026
Conversation

@yzh119
Collaborator

@yzh119 yzh119 commented Jan 28, 2026

📌 Description

We prioritize CuTe-DSL over handwritten CUDA for kernel development because it offers faster JIT compilation. This PR is the first in a series that refactors the simple normalization kernels to CuTe-DSL.

The CUDA code should be ready to remove once we finish end-to-end testing.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • CuTe-DSL–accelerated normalization: RMSNorm (2D/3D), LayerNorm, fused add+RMSNorm, and FP8-quantized variants exposed for runtime use.
    • Shared norm utilities and JIT warmup to improve kernel readiness.
  • Chores

    • Runtime selection and fallback for CuTe-DSL/CUDA normalization with a visibility check.
  • Bug Fixes

    • Safer optional-dependency handling to avoid hard failures when CUDA/CuTe-DSL is unavailable.

@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request marks the initial phase of refactoring the project's normalization kernels to leverage CuTe-DSL, aiming to enhance JIT compilation speed and overall kernel performance. It introduces a comprehensive set of CuTe-DSL-based normalization kernels and integrates them into the existing API with a conditional dispatch mechanism, paving the way for more efficient GPU computations.

Highlights

  • Introduction of CuTe-DSL Kernels: A new module flashinfer/cute_dsl/norm.py has been added, introducing several normalization kernels implemented using NVIDIA's CuTe-DSL. This includes RMSNormKernel, QKRMSNormKernel (for 3D tensors), RMSNormQuantKernel (with FP8 quantization), FusedAddRMSNormKernel, FusedAddRMSNormQuantKernel, and LayerNormKernel.
  • Conditional Kernel Dispatch: The existing flashinfer/norm.py has been updated to conditionally dispatch normalization operations to either the legacy CUDA JIT implementations or the new CuTe-DSL kernels. This is controlled by the FLASHINFER_USE_CUDA_NORM environment variable, allowing for a gradual transition and fallback mechanism.
  • Enhanced Normalization Functionality: The new CuTe-DSL kernels provide optimized implementations for various normalization types, including standard RMSNorm, RMSNorm with FP8 quantization, fused add operations, and LayerNorm. The QKRMSNormKernel specifically addresses 3D tensors with arbitrary strides, improving flexibility.
  • Low-Level Optimizations: The CuTe-DSL implementations incorporate low-level optimizations such as 128-bit vectorized loads, two-stage warp/block reductions, and direct PTX intrinsics for operations like fast reciprocal, min/max, and single-byte FP8 conversion/storage, ensuring high performance and numerical stability.


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review — /gemini review: Performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: Provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help — /gemini help: Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai
Contributor

coderabbitai bot commented Jan 28, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Adds CuTe-DSL normalization kernels (RMSNorm, QK RMSNorm, FP8-quantized and fused variants, LayerNorm), new norm utilities, and runtime dispatch between CUDA-JIT and CuTe-DSL; conditions imports/exports on CuTe-DSL availability and adjusts some TVM-FFI host-call argument representations.
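Among the shared utilities, the two-stage warp/block reduction is the core primitive behind every norm's row sum. A plain-Python model of the warp stage is shown below; this is illustrative only — the real kernels use shuffle intrinsics expressed in CuTe-DSL, not Python loops.

```python
def warp_reduce_sum(lane_vals):
    """Software model of a shfl-down tree reduction within one warp.

    At each step, lane i accumulates the value held by lane
    i + offset; after log2(warp_size) halvings, lane 0 holds the
    warp-wide sum. The warp size is assumed to be a power of two.
    """
    vals = list(lane_vals)
    offset = len(vals) // 2
    while offset > 0:
        for lane in range(offset):
            vals[lane] += vals[lane + offset]
        offset //= 2
    return vals[0]
```

In the block stage (not modeled here), each warp's lane-0 partial is written to shared memory and a single warp reduces those partials the same way.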

Changes

Cohort / File(s) — Summary

  • CuTe-DSL package (flashinfer/cute_dsl/__init__.py): Re-export norm kernel classes and API functions from flashinfer.norm.kernels behind is_cute_dsl_available(); extend __all__.
  • Top-level exports (flashinfer/__init__.py): Wrap rmsnorm_fp4quant / add_rmsnorm_fp4quant imports in try/except, catching ImportError/AttributeError to avoid hard failures when CuTe-DSL is missing.
  • Norm package entry (flashinfer/norm/__init__.py): Add the FLASHINFER_USE_CUDA_NORM flag and dispatch logic: always import gen_norm_module (CUDA JIT warmup) and choose between the CUDA-JIT and CuTe-DSL implementations; rename internal helpers and expose __all__ for the public norm APIs.
  • Norm kernels package init (flashinfer/norm/kernels/__init__.py): New initializer re-exporting CuTe-DSL kernel classes and Python APIs for RMSNorm, QKRMSNorm, RMSNormQuant, FusedAddRMSNorm, FusedAddRMSNormQuant, and LayerNorm.
  • RMSNorm implementations (flashinfer/norm/kernels/rmsnorm.py): Add RMSNormKernel, QKRMSNormKernel, RMSNormQuantKernel, compiled-kernel factories, and Python APIs (rmsnorm_cute, qk_rmsnorm_cute, rmsnorm_quant_cute).
  • Fused Add + RMSNorm (flashinfer/norm/kernels/fused_add_rmsnorm.py): Add FusedAddRMSNormKernel, FusedAddRMSNormQuantKernel, compiled-kernel factories, and APIs (fused_add_rmsnorm_cute, fused_add_rmsnorm_quant_cute).
  • LayerNorm kernel (flashinfer/norm/kernels/layernorm.py): Add LayerNormKernel, a TVM-FFI compiled-kernel factory, and the layernorm_cute API with a tiled/shared-memory implementation.
  • Norm utilities (flashinfer/norm/utils.py): New utilities: FP8 constants and conversion, PTX intrinsics, warp/block reductions, predicate helpers, vector/thread helpers, a layout builder, and dtype-to-string conversion; re-export cutlass utils; add __all__.
  • CuTe-DSL host-call tweaks (flashinfer/cute_dsl/add_rmsnorm_fp4quant.py, flashinfer/cute_dsl/rmsnorm_fp4quant.py): Change the TVM-FFI host-call invocation to pass plain Python int/float for scalar args (remove the Int32(...) / Float32(...) wrappers).

Sequence Diagram(s)

sequenceDiagram
    participant App as Application
    participant NormAPI as flashinfer.norm
    participant Dispatcher as Dispatcher
    participant CUDAJIT as CUDA JIT (gen_norm_module)
    participant CuTeDSL as CuTe-DSL Path
    participant Compiled as Compiled Kernel (TVM/ptx)
    participant GPU as GPU Device

    App->>NormAPI: call rmsnorm(...)
    NormAPI->>Dispatcher: check FLASHINFER_USE_CUDA_NORM / is_cute_dsl_available()
    alt CUDA JIT selected
        Dispatcher->>CUDAJIT: request/jit module
        CUDAJIT->>Compiled: produce kernel
        Compiled->>GPU: execute kernel
    else CuTe-DSL selected
        Dispatcher->>CuTeDSL: request compiled CuTe kernel
        CuTeDSL->>Compiled: produce kernel (TVM-FFI)
        Compiled->>GPU: execute kernel
    end
    GPU-->>Compiled: result
    Compiled-->>NormAPI: output tensor
    NormAPI-->>App: return

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • kaixih
  • aleozlx
  • bkryu
  • jimmyzho
  • nvmbreughe

Poem

🐰 I hopped through tiles and shared-memory nooks,
I threaded warps and peeked in kernel books.
CuTe or CUDA, I choose which way to run —
Norms now dance on GPUs, bright as the sun. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 56.69%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — The PR title accurately captures the main change: refactoring CUDA normalization kernels to a CuTe-DSL implementation, which is the primary objective documented in the PR description and objectives.
  • Description check ✅ Passed — The PR description provides clear context on the refactoring goal (CUDA to CuTe-DSL migration) and the completion status of checklist items, but lacks specific details about file changes and key modifications.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

@yzh119
Collaborator Author

yzh119 commented Jan 28, 2026

/bot run


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request is a significant refactoring effort, moving normalization kernels from a custom CUDA JIT implementation to the CuTe-DSL. This is a commendable step towards improving performance and maintainability. The new flashinfer/cute_dsl/norm.py file is extensive and well-structured. My review has identified a few critical and high-severity issues that need to be addressed, including a bug in the FP8 quantization logic, incorrect API parameter naming, and inefficient shared memory usage. Once these issues are resolved, this will be a solid improvement.

.reg .b16 fp8_pair;
.reg .f32 zero;
mov.f32 zero, 0f00000000;
cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $0;

critical

There is a bug in the PTX inline assembly. The cvt.rn.satfinite.e4m3x2.f32 instruction converts the second source operand and stores it in the upper half of the destination register. The st.global.b8 instruction then stores the lower 8 bits of the register. As written, this will store the converted zero value, not the intended val ($0).

To fix this, you should swap the source operands in the cvt instruction to place the converted value in the lower half of the fp8_pair register.

Suggested change
cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $0;
cvt.rn.satfinite.e4m3x2.f32 fp8_pair, $0, zero;

Comment on lines +1391 to +1393
self.cols_per_tile_f32 * 4 * 2
+ self.cols_per_tile * elem_bytes * 2
+ 2 * self.num_warps * 4

high

The shared memory calculation for LayerNormKernel includes space for gamma/beta in the input dtype, but these shared memory tiles (sGamma, sBeta) are allocated and partitioned but never actually used in the kernel. The kernel reads gamma and beta values directly from the float32 shared memory tiles (sGamma_f32, sBeta_f32).

This wastes a significant amount of shared memory, which can negatively impact performance by reducing occupancy.

You should remove the allocation of sGamma and sBeta (lines 1483-1492) and their partitioning (lines 1565-1566) in the kernel method, and update this shared memory size calculation.

Suggested change
self.cols_per_tile_f32 * 4 * 2
+ self.cols_per_tile * elem_bytes * 2
+ 2 * self.num_warps * 4
self.cols_per_tile_f32 * 4 * 2
+ 2 * self.num_warps * 4
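The before/after byte counts in this suggestion can be expressed as two small helpers (hypothetical names, mirroring the formulas above) to make the saving concrete:

```python
def smem_bytes_current(cols_per_tile_f32, cols_per_tile, elem_bytes, num_warps):
    """Current allocation: float32 gamma/beta tiles, the unused
    input-dtype gamma/beta tiles, and per-warp reduction scratch."""
    return (cols_per_tile_f32 * 4 * 2
            + cols_per_tile * elem_bytes * 2
            + 2 * num_warps * 4)

def smem_bytes_proposed(cols_per_tile_f32, num_warps):
    """Proposed allocation after dropping the unused sGamma/sBeta tiles."""
    return cols_per_tile_f32 * 4 * 2 + 2 * num_warps * 4
```

For example, a 4096-wide fp16 tile (elem_bytes = 2) with 4 warps would save 4096 * 2 * 2 = 16 KiB of shared memory per block, which is what makes the occupancy concern above plausible.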

Comment on lines +1770 to +1787
def tensor_api(
input: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
B: int,
N: int,
eps: float,
num_blocks: int,
) -> None:
compiled_kernel(
input,
weight,
output,
Int32(B),
Int32(N),
Float32(eps),
Int32(num_blocks),
)

high

The enable_pdl parameter is not being passed to the compiled kernel. The qk_rmsnorm_cute function accepts enable_pdl, but it's lost because the tensor_api wrapper doesn't accept it and pass it to the compiled_kernel call.

This is a bug that prevents Programmatic Dependent Launch from being used with this kernel. You should update tensor_api to accept enable_pdl and pass it through. You'll also need to update the call to kernel in qk_rmsnorm_cute (line 2087) to pass this new argument.

Suggested change
def tensor_api(
input: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
B: int,
N: int,
eps: float,
num_blocks: int,
) -> None:
compiled_kernel(
input,
weight,
output,
Int32(B),
Int32(N),
Float32(eps),
Int32(num_blocks),
)
def tensor_api(
input: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
B: int,
N: int,
eps: float,
enable_pdl: bool,
num_blocks: int,
) -> None:
compiled_kernel(
input,
weight,
output,
Int32(B),
Int32(N),
Float32(eps),
enable_pdl,
Int32(num_blocks),
)

Comment on lines +231 to +253
def predicate_k_3d(tXcX: cute.Tensor, limit: int) -> cute.Tensor:
"""Create predicate tensor for bounds checking (3D tensors).

For 3D tensors after local_tile, the last coordinate [2] is the head_dim dimension.
"""
tXpX = cute.make_rmem_tensor(
cute.make_layout(
(
cute.size(tXcX, mode=[0, 1]),
cute.size(tXcX, mode=[1]),
cute.size(tXcX, mode=[2]),
),
stride=(cute.size(tXcX, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tXpX.shape[0]):
for rest_k in cutlass.range_constexpr(tXpX.shape[2]):
# For 3D tensor, coordinate[2] is the head_dim index
tXpX[rest_v, 0, rest_k] = cute.elem_less(
tXcX[(0, rest_v), 0, rest_k][2], limit
)
return tXpX

medium

The function predicate_k_3d is defined but does not appear to be used anywhere in the new code. It seems to be dead code and should be removed to improve code clarity and maintainability.


idX = cute.make_identity_tensor(mX.shape)
gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
cute.local_tile(mY, tiler_mn, (bidx, 0))

medium

The result of this cute.local_tile call is not used, so this line has no effect and can be removed. A similar unused call exists in FusedAddRMSNormQuantKernel.kernel on line 1233.

@flashinfer-bot
Collaborator

GitLab MR !272 has been created, and the CI pipeline #42732703 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/norm.py`:
- Around line 2044-2088: The qk_rmsnorm_cute function accepts enable_pdl but
never forwards it to the kernel compilation (kernel created via
_get_compiled_qk_rmsnorm_kernel uses a hardcoded value); update qk_rmsnorm_cute
to pass the enable_pdl flag into _get_compiled_qk_rmsnorm_kernel (or else remove
enable_pdl from qk_rmsnorm_cute's signature) so the compiled kernel respects PDL
support — locate the _get_compiled_qk_rmsnorm_kernel call in qk_rmsnorm_cute and
change its arguments to include enable_pdl (and ensure any downstream kernel
invocation/signature matches this added parameter).
🧹 Nitpick comments (8)
flashinfer/cute_dsl/norm.py (8)

858-862: Dead code: cute.local_tile(mY, ...) result is unused.

The result of cute.local_tile(mY, tiler_mn, (bidx, 0)) at line 860 is not assigned to a variable. The FP8 output is stored using PTX scalar stores later (lines 920-922), which access mY directly with computed offsets. This call appears to be unnecessary.

♻️ Proposed fix
         idX = cute.make_identity_tensor(mX.shape)
         gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
-        cute.local_tile(mY, tiler_mn, (bidx, 0))
         cX = cute.local_tile(idX, tiler_mn, (bidx, 0))

1231-1236: Same issue: cute.local_tile(mY, ...) result is unused.

Same dead code pattern as in RMSNormQuantKernel.

♻️ Proposed fix
         idX = cute.make_identity_tensor(mX.shape)

-        cute.local_tile(mY, tiler_mn, (bidx, 0))
         gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
         gR = cute.local_tile(mR, tiler_mn, (bidx, 0))
         cX = cute.local_tile(idX, tiler_mn, (bidx, 0))

1564-1567: Dead code: partition_D results are unused.

The results of thr_copy_load.partition_D(sGamma) and thr_copy_load.partition_D(sBeta) are not assigned to variables. Gamma/beta are loaded directly from sGamma_f32/sBeta_f32 at lines 1634-1635.

♻️ Proposed fix
-        # Partitions for gamma/beta (input dtype)
-        thr_copy_load.partition_D(sGamma)
-        thr_copy_load.partition_D(sBeta)
-
         # Register fragments - initialize to zero for proper handling of out-of-bounds threads

2016-2042: Missing @flashinfer_api decorator on public API function.

The rmsnorm_cute function is exported in __all__ and thus part of the public API, but it lacks the @flashinfer_api decorator required by coding guidelines.

Additionally, the enable_pdl parameter is accepted but completely ignored. The kernel is compiled with a hardcoded False value at line 1764. This breaks the API contract with callers who expect PDL to be honored.

♻️ Proposed fix for decorator
+from ..api_logging import flashinfer_api
+
+@flashinfer_api
 def rmsnorm_cute(
     input: torch.Tensor,

As per coding guidelines: "Use @flashinfer_api decorator for debugging API calls."


2090-2113: Same issues: missing @flashinfer_api decorator and unused enable_pdl.

rmsnorm_quant_cute has the same issues as rmsnorm_cute.


2116-2135: Same issues: missing @flashinfer_api decorator and unused enable_pdl.

fused_add_rmsnorm_cute has the same issues.


2138-2170: Same issues: missing @flashinfer_api decorator and unused enable_pdl.

fused_add_rmsnorm_quant_cute has the same issues.


2173-2192: Missing @flashinfer_api decorator.

layernorm_cute is missing the @flashinfer_api decorator. Note that this function doesn't have an enable_pdl parameter, which is consistent since it doesn't expose PDL functionality.

Comment on lines +2044 to +2088
def qk_rmsnorm_cute(
input: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

Supports arbitrary stride - no need to call contiguous().
Each warp processes one (batch, head) pair independently using warp-only reduction.

Args:
input: Input tensor of shape [batch_size, num_heads, head_dim].
Last dimension must be contiguous (stride[-1] == 1).
weight: Weight tensor of shape [head_dim].
output: Output tensor (same shape as input).
eps: Small constant for numerical stability.
weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs.
"""
assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

batch_size, num_heads, head_dim = input.shape
M = batch_size * num_heads

# Kernel configuration
num_warps = 4

# Calculate grid size based on SM count and estimated occupancy
num_sms = get_num_sm(input.device)
blocks_per_sm = 16 # Theoretical max for 128-thread blocks
max_blocks = num_sms * blocks_per_sm
needed_blocks = (M + num_warps - 1) // num_warps
num_blocks = min(max_blocks, needed_blocks)

dtype_str = _torch_dtype_to_str(input.dtype)
kernel = _get_compiled_qk_rmsnorm_kernel(
dtype_str, head_dim, weight_bias, num_warps
)

# Pass 3D tensors directly - kernel handles arbitrary stride
kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)


⚠️ Potential issue | 🟡 Minor

enable_pdl parameter is accepted but not effectively used.

The qk_rmsnorm_cute function accepts enable_pdl but the compiled kernel at line 1764 uses a hardcoded enable_pdl=False. The kernel supports PDL (lines 617-618, 747-748), but the parameter isn't being passed through during compilation.

🔧 Proposed fix to support PDL

To properly support PDL, the compilation would need to be done at runtime with the actual enable_pdl value, or the parameter should be removed from the API signature if PDL is intentionally disabled for CuTe-DSL kernels.

If PDL is intentionally disabled, consider removing the parameter:

 def qk_rmsnorm_cute(
     input: torch.Tensor,
     weight: torch.Tensor,
     output: torch.Tensor,
     eps: float = 1e-6,
     weight_bias: float = 0.0,
-    enable_pdl: bool = False,
 ) -> None:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def qk_rmsnorm_cute(
input: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].
Supports arbitrary stride - no need to call contiguous().
Each warp processes one (batch, head) pair independently using warp-only reduction.
Args:
input: Input tensor of shape [batch_size, num_heads, head_dim].
Last dimension must be contiguous (stride[-1] == 1).
weight: Weight tensor of shape [head_dim].
output: Output tensor (same shape as input).
eps: Small constant for numerical stability.
weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs.
"""
assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"
batch_size, num_heads, head_dim = input.shape
M = batch_size * num_heads
# Kernel configuration
num_warps = 4
# Calculate grid size based on SM count and estimated occupancy
num_sms = get_num_sm(input.device)
blocks_per_sm = 16 # Theoretical max for 128-thread blocks
max_blocks = num_sms * blocks_per_sm
needed_blocks = (M + num_warps - 1) // num_warps
num_blocks = min(max_blocks, needed_blocks)
dtype_str = _torch_dtype_to_str(input.dtype)
kernel = _get_compiled_qk_rmsnorm_kernel(
dtype_str, head_dim, weight_bias, num_warps
)
# Pass 3D tensors directly - kernel handles arbitrary stride
kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)
def qk_rmsnorm_cute(
input: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
) -> None:
"""CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].
Supports arbitrary stride - no need to call contiguous().
Each warp processes one (batch, head) pair independently using warp-only reduction.
Args:
input: Input tensor of shape [batch_size, num_heads, head_dim].
Last dimension must be contiguous (stride[-1] == 1).
weight: Weight tensor of shape [head_dim].
output: Output tensor (same shape as input).
eps: Small constant for numerical stability.
weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
"""
assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"
batch_size, num_heads, head_dim = input.shape
M = batch_size * num_heads
# Kernel configuration
num_warps = 4
# Calculate grid size based on SM count and estimated occupancy
num_sms = get_num_sm(input.device)
blocks_per_sm = 16 # Theoretical max for 128-thread blocks
max_blocks = num_sms * blocks_per_sm
needed_blocks = (M + num_warps - 1) // num_warps
num_blocks = min(max_blocks, needed_blocks)
dtype_str = _torch_dtype_to_str(input.dtype)
kernel = _get_compiled_qk_rmsnorm_kernel(
dtype_str, head_dim, weight_bias, num_warps
)
# Pass 3D tensors directly - kernel handles arbitrary stride
kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)
🧰 Tools
🪛 Ruff (0.14.14)

2050-2050: Unused function argument: enable_pdl

(ARG001)

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 2044 - 2088, The qk_rmsnorm_cute
function accepts enable_pdl but never forwards it to the kernel compilation
(kernel created via _get_compiled_qk_rmsnorm_kernel uses a hardcoded value);
update qk_rmsnorm_cute to pass the enable_pdl flag into
_get_compiled_qk_rmsnorm_kernel (or else remove enable_pdl from
qk_rmsnorm_cute's signature) so the compiled kernel respects PDL support —
locate the _get_compiled_qk_rmsnorm_kernel call in qk_rmsnorm_cute and change
its arguments to include enable_pdl (and ensure any downstream kernel
invocation/signature matches this added parameter).

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42732703: 1/20 passed

@yzh119
Collaborator Author

yzh119 commented Jan 28, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !272 has been updated with latest changes, and the CI pipeline #42752005 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/norm.py`:
- Around line 1875-1882: The parameter enable_pdl in rmsnorm_cute is unused and
triggers ARG001; explicitly mark it as intentionally unused by adding a no-op
assignment (e.g., _ = enable_pdl) or a targeted noqa comment inside rmsnorm_cute
to show API-parity intent, and apply the same change to the other wrapper
functions mentioned in the review so each unused enable_pdl is acknowledged
rather than left unused.
- Around line 1875-2051: The public CuTe-DSL wrapper functions (rmsnorm_cute,
qk_rmsnorm_cute, rmsnorm_quant_cute, fused_add_rmsnorm_cute,
fused_add_rmsnorm_quant_cute, layernorm_cute) need the `@flashinfer_api` decorator
added and the decorator imported from the project’s standard utilities; add a
single import for flashinfer_api near other imports and prepend `@flashinfer_api`
above each of these function definitions so all public entry points are traced
for API-call logging (keep existing signatures and bodies unchanged).
- Around line 371-379: Rename the unused kernel parameter M to _M in the kernel
signatures to silence Ruff ARG002 (e.g., change the argument name in
cute_dsl.norm.LayerNormKernel.kernel and the other flagged kernels
RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel,
FusedAddRMSNormQuantKernel.kernel); update the parameter name only in the
function signature (or alternatively add a targeted "# noqa: ARG002" comment) so
the intent is clear and linters stop reporting the unused argument.
- Around line 1188-1200: The scalar FP8 store computes out_offset assuming
row-major contiguous layout (out_offset = bidx * H + idx), which fails for
non-contiguous mY; update the store in the block that calls
cvt_and_store_f32_to_e4m3/get_ptr_as_int64 to compute the correct linear offset
using the output tensor's stride (e.g., out_offset = bidx * mY.stride[0] + idx)
or mirror the non-quantized kernels by using CuTe's local_tile/partition_D logic
(as in FusedAddRMSNormKernel) to derive the physical address; ensure you
reference mY.stride and preserve idx calculation so cvt_and_store_f32_to_e4m3
receives the correct out_ptr for any layout.
- Around line 835-847: The FP8 store currently computes out_offset as bidx * H +
idx which assumes a contiguous row stride; update the offset calculation to use
the actual row stride (sym_row_stride_y) so stores respect arbitrary output
tensor strides—replace the use of H in out_offset with sym_row_stride_y (i.e.,
compute out_offset = bidx * sym_row_stride_y + idx) in the block that calls
get_ptr_as_int64(mY, Int32(out_offset)) and cvt_and_store_f32_to_e4m3; ensure
any alternative tiled layout approach mirrors how inputs are handled so the
store remains stride-aware.
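The fix both of these comments ask for reduces to one change in the offset arithmetic: index by the output tensor's actual row stride instead of the logical width H. A scalar sketch (names mirror the review text, not the actual kernel code):

```python
def fp8_store_offset(bidx: int, idx: int, row_stride: int) -> int:
    """Element offset for the scalar FP8 store.

    Using the output tensor's row stride keeps the store correct
    when rows are padded or otherwise non-contiguous; for a
    contiguous [B, H] tensor, row_stride == H and this reduces to
    the original bidx * H + idx.
    """
    return bidx * row_stride + idx
```

For example, with H = 100 but a padded row stride of 128, the old bidx * H + idx formula would scatter writes into the wrong rows, while the stride-aware version lands each byte in its padded row.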

Comment on lines +371 to +379
def kernel(
self,
mX: cute.Tensor,
mW: cute.Tensor,
mY: cute.Tensor,
M: Int32,
eps: Float32,
tv_layout: cute.Layout,
tiler_mn: cute.Shape,

⚠️ Potential issue | 🟡 Minor

Silence unused M kernel args to keep Ruff clean.

Ruff reports ARG002 for M in kernel signatures. Since M is not used inside kernels, rename it to _M (or add a targeted # noqa: ARG002) to document intent and satisfy lint. Apply the same pattern to the other kernel methods flagged by Ruff (RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel, FusedAddRMSNormQuantKernel.kernel, LayerNormKernel.kernel).

♻️ Example fix (apply similarly to other kernels)
-        M: Int32,
+        _M: Int32,
🧰 Tools
🪛 Ruff (0.14.14)

376-376: Unused method argument: M

(ARG002)

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 371 - 379, Rename the unused kernel
parameter M to _M in the kernel signatures to silence Ruff ARG002 (e.g., change
the argument name in cute_dsl.norm.LayerNormKernel.kernel and the other flagged
kernels RMSNormQuantKernel.kernel, FusedAddRMSNormKernel.kernel,
FusedAddRMSNormQuantKernel.kernel); update the parameter name only in the
function signature (or alternatively add a targeted "# noqa: ARG002" comment) so
the intent is clear and linters stop reporting the unused argument.

Comment on lines +835 to +847
col_offset = tidx * vec_size
for v in cutlass.range_constexpr(num_vec_blocks):
for e in cutlass.range_constexpr(vec_size):
idx = col_offset + v * threads_per_row * vec_size + e
if idx < H:
# Clamp and convert - use flat index for register tensor
flat_idx = v * vec_size + e
clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
# Use PTX to convert and store FP8 byte
out_offset = bidx * H + idx
out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
cvt_and_store_f32_to_e4m3(clamped, out_ptr)

⚠️ Potential issue | 🟠 Major


Use stride-aware offset calculation for FP8 output store.

out_offset = bidx * H + idx assumes a contiguous row stride equal to H, which breaks for arbitrary-stride outputs declared in the tensor layout (stride = sym_row_stride_y). Replace it with out_offset = bidx * sym_row_stride_y + idx, or apply a consistent tiled layout to the output tensor (as done for the input) so strides are respected automatically.

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 835 - 847, The FP8 store currently
computes out_offset as bidx * H + idx which assumes a contiguous row stride;
update the offset calculation to use the actual row stride (sym_row_stride_y) so
stores respect arbitrary output tensor strides—replace the use of H in
out_offset with sym_row_stride_y (i.e., compute out_offset = bidx *
sym_row_stride_y + idx) in the block that calls get_ptr_as_int64(mY,
Int32(out_offset)) and cvt_and_store_f32_to_e4m3; ensure any alternative tiled
layout approach mirrors how inputs are handled so the store remains
stride-aware.

Comment on lines +1188 to +1200
col_offset = tidx * vec_size
for v in cutlass.range_constexpr(num_vec_blocks):
for e in cutlass.range_constexpr(vec_size):
idx = col_offset + v * threads_per_row * vec_size + e
if idx < H:
# Clamp and convert - use flat index for register tensor
flat_idx = v * vec_size + e
clamped = max(tYrY_f32[flat_idx], Float32(-FLOAT8_E4M3_MAX))
clamped = min(clamped, Float32(FLOAT8_E4M3_MAX))
# Use PTX to convert and store FP8 byte
out_offset = bidx * H + idx
out_ptr = get_ptr_as_int64(mY, Int32(out_offset))
cvt_and_store_f32_to_e4m3(clamped, out_ptr)

⚠️ Potential issue | 🟠 Major


The FP8 scalar store path assumes row-major contiguous output layout.

The hardcoded out_offset = bidx * H + idx breaks non-contiguous outputs. Use CuTe's local_tile and partition_D like the non-quantized kernels (e.g., FusedAddRMSNormKernel), or query the output tensor's stride and compute out_offset = bidx * mY.stride[0] + idx.

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1188 - 1200, The scalar FP8 store
computes out_offset assuming row-major contiguous layout (out_offset = bidx * H
+ idx), which fails for non-contiguous mY; update the store in the block that
calls cvt_and_store_f32_to_e4m3/get_ptr_as_int64 to compute the correct linear
offset using the output tensor's stride (e.g., out_offset = bidx * mY.stride[0]
+ idx) or mirror the non-quantized kernels by using CuTe's
local_tile/partition_D logic (as in FusedAddRMSNormKernel) to derive the
physical address; ensure you reference mY.stride and preserve idx calculation so
cvt_and_store_f32_to_e4m3 receives the correct out_ptr for any layout.
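Independent of the stride question, the clamp in the quoted loop can be sanity-checked against a plain-Python reference. FLOAT8_E4M3_MAX = 448.0 is the largest finite value of the OCP float8 e4m3 format; the PTX conversion itself is not reproduced here, only the clamp that precedes it:

```python
FLOAT8_E4M3_MAX = 448.0  # largest finite magnitude in the OCP float8 e4m3 format

def clamp_for_e4m3(v: float) -> float:
    """Mirror the kernel's clamp that precedes the f32 -> e4m3 PTX conversion."""
    v = max(v, -FLOAT8_E4M3_MAX)
    v = min(v, FLOAT8_E4M3_MAX)
    return v

print([clamp_for_e4m3(v) for v in (-1e9, -1.5, 0.0, 447.5, 1e9)])
# -> [-448.0, -1.5, 0.0, 447.5, 448.0]
```

Without the clamp, out-of-range float32 inputs would saturate to NaN/inf behavior defined by the conversion instruction rather than to the finite e4m3 extremes.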

Comment on lines +1875 to +1882
def rmsnorm_cute(
input: torch.Tensor,
weight: torch.Tensor,
out: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:

⚠️ Potential issue | 🟡 Minor

enable_pdl is unused in most wrappers.

Ruff flags ARG001 for these functions. If the parameter is only for API parity, make the intent explicit (e.g., _ = enable_pdl or a targeted # noqa: ARG001). Otherwise, plumb it through once those kernels support PDL.

✅ Example (apply similarly to other wrappers)
 def rmsnorm_cute(
     input: torch.Tensor,
     weight: torch.Tensor,
     out: torch.Tensor,
     eps: float = 1e-6,
     weight_bias: float = 0.0,
     enable_pdl: bool = False,
 ) -> None:
+    _ = enable_pdl  # reserved for future PDL support
     """CuTe DSL RMSNorm implementation.

Also applies to: 1949-1957, 1975-1982, 1997-2006

🧰 Tools
🪛 Ruff (0.14.14)

1881-1881: Unused function argument: enable_pdl

(ARG001)

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1875 - 1882, The parameter
enable_pdl in rmsnorm_cute is unused and triggers ARG001; explicitly mark it as
intentionally unused by adding a no-op assignment (e.g., _ = enable_pdl) or a
targeted noqa comment inside rmsnorm_cute to show API-parity intent, and apply
the same change to the other wrapper functions mentioned in the review so each
unused enable_pdl is acknowledged rather than left unused.

Comment on lines +1875 to +2051
def rmsnorm_cute(
input: torch.Tensor,
weight: torch.Tensor,
out: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL RMSNorm implementation.

Supports arbitrary stride - no need to call contiguous().
Last dimension must be contiguous (stride[-1] == 1).
"""
H = input.shape[-1]
if input.dim() == 3:
M = input.shape[0] * input.shape[1]
input_2d = input.view(M, H)
out_2d = out.view(M, H)
else:
M = input.shape[0]
input_2d = input
out_2d = out

dtype_str = _torch_dtype_to_str(input.dtype)
kernel = _get_compiled_rmsnorm_kernel(dtype_str, H, weight_bias)
kernel(input_2d, weight, out_2d, M, eps)


def qk_rmsnorm_cute(
input: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL QKRMSNorm for 3D tensors [batch, heads, head_dim].

Supports arbitrary stride - no need to call contiguous().
Each warp processes one (batch, head) pair independently using warp-only reduction.

Args:
input: Input tensor of shape [batch_size, num_heads, head_dim].
Last dimension must be contiguous (stride[-1] == 1).
weight: Weight tensor of shape [head_dim].
output: Output tensor (same shape as input).
eps: Small constant for numerical stability.
weight_bias: Bias added to weight (0 for standard RMSNorm, 1 for Gemma).
enable_pdl: Enable Programmatic Dependent Launch for SM90+ GPUs.
"""
assert input.dim() == 3, "QKRMSNorm expects 3D input [batch, heads, head_dim]"

batch_size, num_heads, head_dim = input.shape
M = batch_size * num_heads

# Kernel configuration
num_warps = 4

# Calculate grid size based on SM count and estimated occupancy
num_sms = get_num_sm(input.device)
blocks_per_sm = 16 # Theoretical max for 128-thread blocks
max_blocks = num_sms * blocks_per_sm
needed_blocks = (M + num_warps - 1) // num_warps
num_blocks = min(max_blocks, needed_blocks)

dtype_str = _torch_dtype_to_str(input.dtype)
kernel = _get_compiled_qk_rmsnorm_kernel(
dtype_str, head_dim, weight_bias, num_warps
)

# Pass 3D tensors directly - kernel handles arbitrary stride
kernel(input, weight, output, batch_size, num_heads, eps, num_blocks)


def rmsnorm_quant_cute(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: float,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL RMSNorm + FP8 quantization implementation.

Supports arbitrary stride - no need to call contiguous().
Last dimension must be contiguous (stride[-1] == 1).
"""

H = input.shape[-1]
M = input.shape[0]

dtype_str = _torch_dtype_to_str(input.dtype)
out_dtype_str = _torch_dtype_to_str(out.dtype)
kernel = _get_compiled_rmsnorm_quant_kernel(
dtype_str, out_dtype_str, H, weight_bias
)
kernel(out, input, weight, M, scale, eps)


def fused_add_rmsnorm_cute(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL Fused Add + RMSNorm implementation.

Supports arbitrary stride - no need to call contiguous().
Last dimension must be contiguous (stride[-1] == 1).
"""

H = input.shape[-1]
M = input.shape[0]

dtype_str = _torch_dtype_to_str(input.dtype)
kernel = _get_compiled_fused_add_rmsnorm_kernel(dtype_str, H, weight_bias)
kernel(input, residual, weight, M, eps)


def fused_add_rmsnorm_quant_cute(
out: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: float,
eps: float = 1e-6,
weight_bias: float = 0.0,
enable_pdl: bool = False,
) -> None:
"""CuTe DSL Fused Add + RMSNorm + FP8 quantization implementation.

Supports arbitrary stride - no need to call contiguous().
Last dimension must be contiguous (stride[-1] == 1).
"""

H = input.shape[-1]
M = input.shape[0]

dtype_str = _torch_dtype_to_str(input.dtype)
out_dtype_str = _torch_dtype_to_str(out.dtype)
kernel = _get_compiled_fused_add_rmsnorm_quant_kernel(
dtype_str, out_dtype_str, H, weight_bias
)
kernel(
out,
input,
residual,
weight,
M,
scale,
eps,
)


def layernorm_cute(
out: torch.Tensor,
input: torch.Tensor,
gamma: torch.Tensor,
beta: torch.Tensor,
eps: float = 1e-6,
) -> None:
"""CuTe DSL LayerNorm implementation.

Supports arbitrary stride - no need to call contiguous().
Last dimension must be contiguous (stride[-1] == 1).
"""

H = input.shape[-1]
M = input.shape[0]

dtype_str = _torch_dtype_to_str(input.dtype)
gamma_dtype_str = _torch_dtype_to_str(gamma.dtype)
kernel = _get_compiled_layernorm_kernel(dtype_str, gamma_dtype_str, H)
kernel(out, input, gamma, beta, M, eps)

🛠️ Refactor suggestion | 🟠 Major

Add @flashinfer_api on public CuTe-DSL API wrappers.

The public Python entry points here (e.g., rmsnorm_cute, qk_rmsnorm_cute, rmsnorm_quant_cute, fused_add_rmsnorm_cute, fused_add_rmsnorm_quant_cute, layernorm_cute) should be decorated for API-call tracing. Please apply @flashinfer_api (and import it from the standard location used elsewhere in the repo).

As per coding guidelines: Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats).

🧰 Tools
🪛 Ruff (0.14.14)

1881-1881: Unused function argument: enable_pdl

(ARG001)


1909-1909: Unused function argument: enable_pdl

(ARG001)


1956-1956: Unused function argument: enable_pdl

(ARG001)


1981-1981: Unused function argument: enable_pdl

(ARG001)


2005-2005: Unused function argument: enable_pdl

(ARG001)

🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/norm.py` around lines 1875 - 2051, The public CuTe-DSL
wrapper functions (rmsnorm_cute, qk_rmsnorm_cute, rmsnorm_quant_cute,
fused_add_rmsnorm_cute, fused_add_rmsnorm_quant_cute, layernorm_cute) need the
`@flashinfer_api` decorator added and the decorator imported from the project’s
standard utilities; add a single import for flashinfer_api near other imports
and prepend `@flashinfer_api` above each of these function definitions so all
public entry points are traced for API-call logging (keep existing signatures
and bodies unchanged).

@bkryu
Collaborator

bkryu commented Jan 28, 2026

Hi @yzh119, is this a translation of the existing norm kernels from CUDA to CuTe DSL? or is it a re-write for better performance?

Asking because I'm curious whether it resolves #2396


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/norm.py`:
- Around line 468-471: The unconditional imports of rmsnorm_fp4quant and
add_rmsnorm_fp4quant from .cute_dsl can raise when cutlass/cute_dsl isn't
available; wrap these imports in a try/except (or check
cute_dsl.is_cute_dsl_available()) and fall back to setting rmsnorm_fp4quant =
None and add_rmsnorm_fp4quant = None so the module can still import and the
existing _USE_CUDA_NORM/CUDA fallback logic can decide at runtime; update any
usage sites to handle None accordingly.

Comment on lines +468 to +471
# CuTe-DSL fused RMSNorm + FP4 Quantization kernels
# These require CuTe-DSL to be available and SM100+ (Blackwell) GPUs
try:
from .cute_dsl import rmsnorm_fp4quant, add_rmsnorm_fp4quant
except ImportError:
# CuTe-DSL not available
rmsnorm_fp4quant = None # type: ignore[misc,assignment]
add_rmsnorm_fp4quant = None # type: ignore[misc,assignment]
# These require SM100+ (Blackwell) GPUs
from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant

⚠️ Potential issue | 🟠 Major


Wrap FP4 imports in try/except to prevent import failures on systems without cutlass.

The rmsnorm_fp4quant and add_rmsnorm_fp4quant imports are unconditional, but cute_dsl/__init__.py only exports them when is_cute_dsl_available() returns True (i.e., when cutlass is installed). Without error handling, importing flashinfer.norm will fail on systems without cutlass, even though the module provides CUDA-based fallbacks via _USE_CUDA_NORM. This breaks backward compatibility.

Suggested pattern for graceful degradation
 # CuTe-DSL fused RMSNorm + FP4 Quantization kernels
 # These require SM100+ (Blackwell) GPUs
-from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
-from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
+try:
+    from .cute_dsl import rmsnorm_fp4quant as rmsnorm_fp4quant
+    from .cute_dsl import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
+except ImportError:
+    rmsnorm_fp4quant = None
+    add_rmsnorm_fp4quant = None
🤖 Prompt for AI Agents
In `@flashinfer/norm.py` around lines 468 - 471, The unconditional imports of
rmsnorm_fp4quant and add_rmsnorm_fp4quant from .cute_dsl can raise when
cutlass/cute_dsl isn't available; wrap these imports in a try/except (or check
cute_dsl.is_cute_dsl_available()) and fall back to setting rmsnorm_fp4quant =
None and add_rmsnorm_fp4quant = None so the module can still import and the
existing _USE_CUDA_NORM/CUDA fallback logic can decide at runtime; update any
usage sites to handle None accordingly.
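The graceful-degradation pattern suggested above can be sketched generically (the module and function names here are placeholders, not the real flashinfer layout):

```python
# Optional-dependency import with a None fallback, so the importing module
# still loads when the optional backend is absent.
try:
    from some_optional_backend import fancy_kernel  # hypothetical module
except ImportError:
    fancy_kernel = None  # callers must check for None before dispatching

def normalize(x, use_fancy=True):
    """Dispatch to the optional kernel when present, else a plain fallback."""
    if use_fancy and fancy_kernel is not None:
        return fancy_kernel(x)
    # Fallback path: plain-Python RMS normalization stand-in.
    mean_sq = sum(v * v for v in x) / len(x)
    scale = (mean_sq + 1e-6) ** -0.5
    return [v * scale for v in x]

print(fancy_kernel is None)  # -> True here: the optional module does not exist
```

Usage sites then branch on the sentinel (here inside normalize) instead of failing at import time, which is what keeps the _USE_CUDA_NORM fallback usable.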

@aleozlx
Collaborator

aleozlx commented Jan 28, 2026

Wanted to ask you @yzh119 about the reason we put all these things in cute_dsl.
Would it be a good idea to split the kernels in the cute_dsl folder into different ops in the future? Organizing them by op feels more understandable to me, and the language is a next-level detail.

@yzh119
Collaborator Author

yzh119 commented Jan 28, 2026

would it be a good idea we put the kernels in the cute_dsl folder into different ops in the future? because i feel organizing them by ops feel more understandable and the language is a next level detail

We should categorize kernels by functionality, not by source language. All kernels inside cute_dsl should be refactored; we can do that in another PR.

For this specific PR, let me make norm a module and move the cute-dsl code under that folder.

@bkryu
Collaborator

bkryu commented Jan 28, 2026

cc @kahyunnam

@flashinfer-bot
Collaborator

[CANCELED] Pipeline #42752005: canceled

@bkryu
Collaborator

bkryu commented Mar 3, 2026

With reference to #2459, scale would need to be a tensor for it to be CUDA-graph compatible; a similar change would be required for rmsnorm_quant_cute also.

@bkryu can this be addressed before merging, as it blocks my sglang changes?

Thanks @DevashishLal-CB. I absorbed #2459 in a newly added commit a7690ec by essentially bringing over all the changes. Is this what you were looking for?

@bkryu
Collaborator

bkryu commented Mar 3, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !272 has been updated with latest changes, and the CI pipeline #45268078 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #45268078: canceled

@bkryu
Collaborator

bkryu commented Mar 4, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !272 has been created, and the CI pipeline #45342593 is currently running. I'll report back once the pipeline job completes.

@DevashishLal-CB
Contributor

With reference to #2459, scale would need to be a tensor for it to be CUDA-graph compatible; a similar change would be required for rmsnorm_quant_cute also.

@bkryu can this be addressed before merging, as it blocks my sglang changes?

Thanks @DevashishLal-CB. I absorbed #2459 in a newly added commit a7690ec by essentially bringing over all the changes. Is this what you were looking for?

Yup thanks, LGTM!

@bkryu bkryu added the op: norm label Mar 4, 2026
@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #45342593: 10/20 passed

…iately. Eliminated live-but-unnecessary smem staging in sGamma_f32 and sBeta_f32 that are allocated in smem but never shared

@kahyunnam kahyunnam left a comment


lgtm


@jimmyzho jimmyzho left a comment


for API + module level, LGTM

@bkryu
Collaborator

bkryu commented Mar 12, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !272 has been updated with latest changes, and the CI pipeline #45936280 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator

bkryu commented Mar 12, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #45936280 has been cancelled.

@bkryu
Collaborator

bkryu commented Mar 12, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !272 has been updated with latest changes, and the CI pipeline #45942144 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45942144: 10/20 passed

@bkryu bkryu merged commit 8bf921a into flashinfer-ai:main Mar 12, 2026
57 of 70 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
Co-authored-by: Yaxing Cai <caiyaxing666@gmail.com>
Co-authored-by: Brian Ryu <bryu@nvidia.com>
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
Co-authored-by: Yaxing Cai <caiyaxing666@gmail.com>
Co-authored-by: Brian Ryu <bryu@nvidia.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
aleozlx added a commit that referenced this pull request Mar 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Fix API-breaking changes for the 0.6.7 release.

## 🔍 Related Issues (Gated-by PRs)


https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

**API changes review**

API changes since v0.6.6:

- **PR #2520 + commit e35c19e** (fixed to be compatible)
  - `xqa()`: added `k_sf_cache=None`, `v_sf_cache=None` as keyword-only parameters (after `*`). Backward-compatible.
- **PR #2618** (has PR #2730 to fix it)
  - `gated_delta_rule_mtp()`: `disable_state_update: bool = True` → `Optional[bool] = None`. Still defaults to `True` at runtime but emits a deprecation warning; the default will flip to `False` in 0.7.0.
- **PR #2775** (expected: cute DSL MoE cleanup)
  - `blockscaled_contiguous_grouped_gemm_nvfp4()`: entire `@flashinfer_api`-decorated function deleted.
  - `blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()`: entire `@flashinfer_api`-decorated function deleted.
  - `blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()`: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` parameter.
  - `blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()`: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` parameter.
  - `CuteDslMoEWrapper.__init__()`: added `enable_pdl: bool = True` parameter. Backward-compatible.
  - `cute_dsl_fused_moe_nvfp4()`: added `enable_pdl: bool = True` parameter. Backward-compatible.
- **PR #2428**
  - `rmsnorm_quant()`: `scale: float` → `scale: Union[float, torch.Tensor]`; return type `torch.Tensor` → `None`.
  - `fused_add_rmsnorm_quant()`: `scale: float` → `scale: Union[float, torch.Tensor]`.
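
Because `rmsnorm_quant()` now returns `None` and only writes into a caller-provided output buffer, call sites that relied on the returned tensor need a small adaptation. Below is a minimal, hypothetical compatibility shim; the wrapper name, the stand-in kernels, and the argument passthrough are illustrative, not FlashInfer APIs.

```python
def call_norm_quant_compat(fn, out, *args, **kwargs):
    """Call a norm-quant kernel `fn` that writes its result into `out`.

    Works with both the old convention (fn returns the output tensor)
    and the new one (fn returns None and only mutates `out`).
    """
    result = fn(out, *args, **kwargs)
    # New-style kernels return None; fall back to the mutated `out` buffer.
    return out if result is None else result


# Stand-ins for the two conventions (illustrative, not real kernels):
def old_style(out, x, scale):
    out.extend(v * scale for v in x)
    return out  # old API: returns the output


def new_style(out, x, scale):
    out.extend(v * scale for v in x)
    return None  # new API: mutates `out` only
```

With the shim, both conventions yield the same result at the call site.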

- **Quantization functions** (relocated, not removed)
  - All quantization APIs (`fp4_quantize`, `block_scale_interleave`, `e2m1_and_ufp8sf_scale_to_float`, `shuffle_matrix_a`, `shuffle_matrix_sf_a`, `nvfp4_quantize`, `nvfp4_batched_quantize`, `scaled_fp4_grouped_quantize`, `mxfp4_quantize`, `mxfp4_dequantize`, `mxfp4_dequantize_host`, `mxfp8_quantize`, `mxfp8_dequantize_host`) were moved from `flashinfer/fp4_quantization.py` and `flashinfer/fp8_quantization.py` to `flashinfer/quantization/`. Signatures, `@flashinfer_api` decorators, and `__init__.py` exports are preserved. No breakage.
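
The `disable_state_update` transition noted for PR #2618 above follows a common default-migration pattern: accept `None` as a sentinel, warn, and substitute the old default until the flip. A self-contained sketch of that pattern in isolation (the function name is illustrative, not the FlashInfer implementation):

```python
import warnings


def resolve_disable_state_update(disable_state_update=None):
    # Sketch of the migration pattern: None preserves today's default (True)
    # but emits a DeprecationWarning; 0.7.0 is slated to flip the default
    # to False, at which point this branch changes its fallback value.
    if disable_state_update is None:
        warnings.warn(
            "disable_state_update currently defaults to True but will "
            "default to False in 0.7.0; pass it explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        return True
    return disable_state_update
```

Callers that already pass the flag explicitly see no warning and no behavior change across the flip.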

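The `kv_block_scales` parameter added to decode (PR #2520) accepts either a single tensor or a `(key, value)` tuple; the diff below shows the tuple branch of the unpacking. A standalone sketch of that convention, where the shared-single-object branch is an assumption on my part (the quoted diff is truncated before it):

```python
def unpack_kv_block_scales(kv_block_scales):
    """Split an optional tensor-or-tuple `kv_block_scales` into K and V scales.

    Tuple -> separate key/value scales; single object -> assumed shared by
    key and value (hypothetical branch); None -> block scales disabled.
    """
    key_block_scales = None
    value_block_scales = None
    if kv_block_scales is not None:
        if isinstance(kv_block_scales, tuple):
            key_block_scales, value_block_scales = kv_block_scales
        else:
            key_block_scales = value_block_scales = kv_block_scales
    return key_block_scales, value_block_scales
```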
```diff
$ git diff v0.6.6 | grep -A20 "@flashinfer_api"                                               
     @flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
         sinks: Optional[torch.Tensor] = None,
         q_len_per_req: Optional[int] = 1,
         skip_softmax_threshold_scale_factor: Optional[float] = None,
+        kv_block_scales: Optional[
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        ] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.

@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper:
             enable_pdl = device_support_pdl(q.device)
         k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)

+        # Unpack kv_block_scales
+        key_block_scales = None
+        value_block_scales = None
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                key_block_scales, value_block_scales = kv_block_scales
--
-@flashinfer_api
-def fp4_quantize(
-    input: torch.Tensor,
-    global_scale: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    sf_use_ue8m0: bool = False,
-    is_sf_swizzled_layout: bool = True,
-    is_sf_8x4_layout: bool = False,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to FP4 format.
-
-    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
-@flashinfer_api
-def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
-    """Swizzle block scale tensor for FP4 format.
-
-    This function swizzles the block scale tensor to optimize memory access patterns
-    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
-
-    Args:
-        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
-
-    Returns:
-        torch.Tensor: Swizzled tensor with the same shape as input.
-
-    Raises:
-        AssertionError: If input dtype is not uint8 or bfloat16.
-    """
-    # TODO(shuw): check input dtype is uint8
-    assert (
-        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
-    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
-
--
-@flashinfer_api
-def e2m1_and_ufp8sf_scale_to_float(
-    e2m1_tensor: torch.Tensor,
-    ufp8_scale_tensor: torch.Tensor,
-    global_scale_tensor: Optional[torch.Tensor] = None,
-    sf_vec_size: int = 16,
-    ufp8_type: int = 1,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
-
-    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
-    back to float values using the associated UFP8 scale factors and global scale.
-
-    Args:
-        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
-        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
-        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
-@flashinfer_api
-def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
-    """
-    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
-    """
-    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
-
-    return input_tensor[row_indices.to(input_tensor.device)]
-
-
-@flashinfer_api
-def shuffle_matrix_sf_a(
-    input_tensor: torch.Tensor,
-    epilogue_tile_m: int,
-    num_elts_per_sf: int = 16,
-):
-    """
-    Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat.
-    `shuffleMatrixSfA` expects the input to be in 128x4 layout and then
-    apply the same shuffling in `shuffleMatrixA` and writes out in 128x4
-    layout.
-    This function expects the input to be in linear layout. It's done this
-    way because the scaling factors in the NVFP4 checkpoints are quantized
-    and are in linear layout.
-    This function doesn't add padding.
-    """
-
-    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
-
-    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
-
--
-@flashinfer_api
-def nvfp4_quantize(
-    a,
-    a_global_sf,
-    sfLayout=SfLayout.layout_128x4,
-    do_shuffle=False,
-    sf_vec_size=16,
-    enable_pdl=None,
-):
-    """
-    Quantize input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
-        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-
--
-@flashinfer_api
-def mxfp4_quantize(a):
-    """
-    Quantize input tensor to MXFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-            - Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-    """
-    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
-    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
-    return a_fp4, a_sf
-
-
-@flashinfer_api
-def mxfp4_dequantize(a_fp4, a_sf):
-    """
-    Dequantize input tensor from MXFP4 format.
-
-    Parameters:
-        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    return e2m1_and_ufp8sf_scale_to_float(
-        a_fp4.cpu().view(torch.uint8),
-        a_sf.cpu().view(torch.uint8).reshape(-1),
-        torch.tensor([1.0], device=a_fp4.device),
-        32,
-        0,
-        True,
-    )
-
--
-@flashinfer_api
-def mxfp4_dequantize_host(
-    weight: torch.Tensor,
-    scale: torch.Tensor,
-    group_size: int = 32,
-) -> torch.Tensor:
-    """
-    Dequantize input tensor from MXFP4 format on host.
-
-    Parameters:
-        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
-        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
-        group_size (int, optional): Group size for dequantization. Defaults to 32.
-
-    Returns:
-        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
-    """
-    # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future
-    major, minor = get_compute_capability(
-        torch.device("cuda:0")
-    )  # use any cuda device to get a compute capability
--
-@flashinfer_api
-def nvfp4_batched_quantize(
-    a,
-    a_global_sf,
-    sf_vec_size=16,
-):
-    """
-    Quantize batched input tensor to NVFP4 format.
-
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
--
-@flashinfer_api
-def scaled_fp4_grouped_quantize(
-    a,
-    mask,
-    a_global_sf,
-):
-    """
-    quantize batched input tensor to NVFP4 format with mask.
-    Parameters:
-        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
-        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
-        mask (torch.Tensor): Mask tensor to apply before quantization.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
-            - Scale factors tensor with shape determined by layout and sf_vec_size
-    """
-    major, minor = get_compute_capability(a.device)
-    device_arch = f"{major * 10 + minor}"
-    a_fp4, a_sf = get_fp4_quantization_module(
-        device_arch
--
-@flashinfer_api
-def mxfp8_quantize(
-    input: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-    alignment: int = 32,
-    enable_pdl: Optional[bool] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Quantize input tensor to MxFP8 format.
-
-    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
-    with associated scale factors. It supports various input data types and scale factor layouts.
-
-    Args:
-        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
-        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
-        alignment (int, optional): sfVecSize. Defaults to 32.
-        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
-            If None, automatically detects based on device capability. Defaults to None.
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-            - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3
--
-@flashinfer_api
-def mxfp8_dequantize_host(
-    input: torch.Tensor,
-    scale_tensor: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-) -> torch.Tensor:
-    """Dequantize input tensor from MxFP8 format.
-
-    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
-    back to float values using the associated scale factors.
-
-    Args:
-        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
-        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
-        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
-
-    Returns:
-        torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32.
-
-    """
-
--
-@flashinfer_api
 def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     vectorized_f32: bool = True,
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads.

@@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
     major, minor = get_compute_capability(a.device)
     if major != 10:
         raise ValueError(
-            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). "
+            f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). "
             f"Got SM{major}{minor}."
         )

--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (128, 128),
-    cluster_shape_mn: Tuple[int, int] = (1, 1),
-    sm_count: Optional[int] = None,
-) -> torch.Tensor:
-    """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization.
-
--
-@flashinfer_api
 def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
     cluster_shape_mn: Tuple[int, int] = (2, 1),
     raster_along_m: bool = False,
     sm_count: Optional[int] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads.

@@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4(
             expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1.
         token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16
         out: Optional output tensor, shape (seq_len, n). Created if None.
-             This tensor is used for atomic accumulation, so it should be zero-initialized.
+             This tensor is used for atomic accumulation. If `out` is
+             provided, it must already be zero-initialized by the caller.
+             If `out` is None, this function allocates a zero-initialized
+             output tensor. Passing a non-zeroed `out` buffer will silently
--
-@flashinfer_api
-def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    a_scale: torch.Tensor,
-    b_scale: torch.Tensor,
-    alpha: torch.Tensor,
-    tile_idx_to_group_idx: torch.Tensor,
-    num_non_exiting_tiles: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    out_scale: Optional[torch.Tensor] = None,
-    global_scale: Optional[torch.Tensor] = None,
-    *,
-    ab_dtype: str = "float4_e2m1fn",
-    sf_dtype: str = "float8_e4m3fn",
-    c_dtype: str = "bfloat16",
-    sf_vec_size: int = 16,
-    mma_tiler_mn: Tuple[int, int] = (256, 128),
-    cluster_shape_mn: Tuple[int, int] = (2, 1),
-    vectorized_f32: bool = True,
-    sm_count: Optional[int] = None,
--
     @flashinfer_api
     def __init__(
         self,
@@ -347,6 +355,7 @@ class CuteDslMoEWrapper:
         sf_vec_size: int = 16,
         output_dtype: torch.dtype = torch.bfloat16,
         device: str = "cuda",
+        enable_pdl: bool = True,
     ):
         """Initialize the MoE wrapper.

@@ -363,6 +372,7 @@ class CuteDslMoEWrapper:
             sf_vec_size: Scale factor vector size. Default: 16.
             output_dtype: Output data type. Default: torch.bfloat16.
             device: Device for buffer allocation. Default: "cuda".
+            enable_pdl: Enable Programmatic Dependent Launch. Default: True.
         """
         self.num_experts = num_experts
         self.top_k = top_k
@@ -376,6 +386,7 @@ class CuteDslMoEWrapper:
         self.sf_vec_size = sf_vec_size
--
     @flashinfer_api
@@ -550,9 +570,10 @@ class CuteDslMoEWrapper:
                 f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})"
             )

-        # Allocate output buffer if not using pre-allocated one
+        # Slice the pre-allocated buffer to the active batch so that
+        # _moe_core_impl only zeros num_tokens rows, not max_num_tokens.
         if self.use_cuda_graph:
-            moe_output = self._moe_output
+            moe_output = self._moe_output[:num_tokens]
         else:
             moe_output = torch.empty(
                 (num_tokens, self.hidden_size),
@@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Internal implementation called by auto-tuner for functional API."""
--
 @flashinfer_api
 def cute_dsl_fused_moe_nvfp4(
     x: torch.Tensor,
@@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4(
     use_fused_finalize: bool = True,
     moe_output: Optional[torch.Tensor] = None,
     aux_stream: Optional[torch.cuda.Stream] = None,
+    enable_pdl: bool = True,
 ) -> torch.Tensor:
     """Run fused MoE computation using CuteDSL NVFP4 kernels.

+    Supported architectures: SM100, SM103.
+
     This is the simple functional API. For CUDA graph support, use
     `CuteDslMoEWrapper` instead.

@@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4(
         local_expert_offset=local_expert_offset,
         use_fused_finalize=use_fused_finalize,
         output_dtype=output_dtype,
+        enable_pdl=enable_pdl,
--
 @flashinfer_api
 def gated_delta_rule_decode_pretranspose(
     q: torch.Tensor,
@@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose(
         - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16
           and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used
           (supports both the direct ``state`` path and the pool+indices path).
-        - pool+indices (``initial_state``/``initial_state_indices``) only supported
-          via the bf16 fast path; float32 state raises an error.
+        - pool+indices (``initial_state``/``initial_state_indices``) supported on
+          both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path
+          (T=1). The float32 path also supports negative indices for padding.
         - Legacy path (float32 state, T=1): K and V must be multiples of 4.
     """
     # Validate input shapes
@@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose(
         return_state = initial_state if use_pool else state
         return output, return_state

-    # Legacy path: T=1 only, float32 state (no pool+indices support)
-    assert not use_pool, (
--
 @flashinfer_api
 def gated_delta_rule_mtp(
     q: torch.Tensor,
@@ -2427,7 +489,7 @@ def gated_delta_rule_mtp(
     scale: Optional[float] = None,
     output: Optional[torch.Tensor] = None,
     intermediate_states_buffer: Optional[torch.Tensor] = None,
-    disable_state_update: bool = True,
+    disable_state_update: Optional[bool] = None,
     use_qk_l2norm: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -2463,8 +525,15 @@ def gated_delta_rule_mtp(
         intermediate_states_buffer (Optional[torch.Tensor]):
             Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``.
             If None, intermediate states are not cached.
-        disable_state_update (bool):
-            If True, the initial state is not updated. Default: ``True``.
+        disable_state_update (Optional[bool]):
+            If True, the initial state is not updated. Currently defaults to ``True``.
+            Please pass this argument explicitly — the default will change to ``False``
--
 @flashinfer_api
@@ -60,16 +120,14 @@ def rmsnorm(
     output: torch.Tensor
         Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size).
     """
-    if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
     if out is None:
         out = torch.empty_like(input)
-    _rmsnorm(out, input, weight, eps, enable_pdl)
+    _rmsnorm_impl(out, input, weight, eps, enable_pdl)
     return out


 @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
-def _rmsnorm(
+def _rmsnorm_impl(
     out: torch.Tensor,
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -78,11 +136,21 @@ def _rmsnorm(
--
 @flashinfer_api
 def fmha_v2_prefill_deepseek(
     query: torch.Tensor,
@@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek(
         If return_lse is False, the output will be a single tensor.
     """
     if not is_sm12x_supported(query.device):
-        major, minor = get_compute_capability(query.device)
-        if major == 12:
-            min_cuda = "13.0" if minor >= 1 else "12.8"
-            raise ValueError(
-                f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} "
-                f"for SM12{minor}x GPUs."
-            )
         raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.")
     assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, (
         "currently only support deepseek r1 192 query and 128 value"
     )
-    module = get_trtllm_fmha_v2_module()
+    module = get_trtllm_fmha_v2_sm120_module()
     is_e4m3 = query.dtype == torch.float8_e4m3fn
--
+@flashinfer_api
+def trtllm_fmha_v2_prefill(
+    qkv: Union[
+        torch.Tensor,
+        Tuple[torch.Tensor, torch.Tensor],
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ],
+    input_layout: str,
+    workspace_buffer: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_q_len: int,
+    max_kv_len: int,
+    bmm1_scale: float,
+    bmm2_scale: float,
+    batch_size: int,
+    cum_seq_lens_q: torch.Tensor,
+    cum_seq_lens_kv: torch.Tensor,
+    block_tables: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    out_dtype: Optional[Union[torch.dtype, str]] = None,
+    sinks: Optional[List[torch.Tensor]] = None,
--
+@flashinfer_api
+def fp4_quantize(
+    input: torch.Tensor,
+    global_scale: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    sf_use_ue8m0: bool = False,
+    is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to FP4 format.
+
+    This function implements FP4 quantization that converts input tensors to a compressed FP4 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
--
+@flashinfer_api
+def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor:
+    """Swizzle block scale tensor for FP4 format.
+
+    This function swizzles the block scale tensor to optimize memory access patterns
+    for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128.
+
+    Args:
+        unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16.
+
+    Returns:
+        torch.Tensor: Swizzled tensor with the same shape as input.
+
+    Raises:
+        AssertionError: If input dtype is not uint8 or bfloat16.
+    """
+    # TODO(shuw): check input dtype is uint8
+    assert (
+        unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16
+    ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}"
+
--
+@flashinfer_api
+def e2m1_and_ufp8sf_scale_to_float(
+    e2m1_tensor: torch.Tensor,
+    ufp8_scale_tensor: torch.Tensor,
+    global_scale_tensor: Optional[torch.Tensor] = None,
+    sf_vec_size: int = 16,
+    ufp8_type: int = 1,
+    is_sf_swizzled_layout: bool = True,
+) -> torch.Tensor:
+    """Convert E2M1 format tensor and UFP8 scale factors to float tensor.
+
+    This function performs dequantization by converting a packed FP4 tensor in E2M1 format
+    back to float values using the associated UFP8 scale factors and global scale.
+
+    Args:
+        e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8.
+        ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8.
+        global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1.
+        is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True.
--
+@flashinfer_api
+def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
+    """
+    PyTorch equivalent of trtllm-gen `shuffleMatrixA`
+    """
+    row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)
+
+    return input_tensor[row_indices.to(input_tensor.device)]
+
+
+@flashinfer_api
+def shuffle_matrix_sf_a(
+    input_tensor: torch.Tensor,
+    epilogue_tile_m: int,
+    num_elts_per_sf: int = 16,
+):
+    """
+    CUDA implementation of trtllm-gen `shuffleMatrixSfA`, with one caveat:
+    `shuffleMatrixSfA` expects its input in 128x4 layout, applies the same
+    shuffling as `shuffleMatrixA`, and writes the output back in 128x4 layout.
+    This function instead expects the input in linear layout, because the
+    scaling factors in NVFP4 checkpoints are quantized and stored in linear
+    layout.
+    This function does not add padding.
+    """
+
+    row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m)
+
+    w_shuffled = input_tensor[row_indices.to(input_tensor.device)]
+
--
+@flashinfer_api
+def nvfp4_quantize(
+    a,
+    a_global_sf,
+    sfLayout=SfLayout.layout_128x4,
+    do_shuffle=False,
+    sf_vec_size=16,
+    enable_pdl=None,
+):
+    """
+    Quantize input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4.
+        do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only the TRTLLM backend needs to shuffle the tensor B scale factors.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability. Defaults to None.
+
--
+@flashinfer_api
+def mxfp4_quantize(
+    a: torch.Tensor,
+    backend: str = "cuda",
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        backend (str, optional): Backend to use for quantization.
+            - "cuda": Use CUDA kernel (default, stable)
+            - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**)
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic
+            Dependent Launch). Only used when backend="cute-dsl".
+            If None, automatically detects based on device capability.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
--
+@flashinfer_api
+def mxfp4_dequantize(a_fp4, a_sf):
+    """
+    Dequantize input tensor from MXFP4 format.
+
+    Parameters:
+        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    return e2m1_and_ufp8sf_scale_to_float(
+        a_fp4.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8).reshape(-1),
+        torch.tensor([1.0], device=a_fp4.device),
+        32,
+        0,
+        True,
+    )
+
--
+@flashinfer_api
+def mxfp4_dequantize_host(
+    weight: torch.Tensor,
+    scale: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Dequantize input tensor from MXFP4 format on host.
+
+    Parameters:
+        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+        group_size (int, optional): Group size for dequantization. Defaults to 32.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    # NOTE(Zihao): the CPU op should be decoupled from CUDA ops because it is device-independent; refactor this in the future
+    major, minor = get_compute_capability(
+        torch.device("cuda:0")
+    )  # use any cuda device to get a compute capability
--
+@flashinfer_api
+def nvfp4_batched_quantize(
+    a,
+    a_global_sf,
+    sf_vec_size=16,
+):
+    """
+    Quantize batched input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
--
+@flashinfer_api
+def nvfp4_quantize_paged_kv_cache(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_layout: str = "HND",
+    k_global_sf: Optional[torch.Tensor] = None,
+    v_global_sf: Optional[torch.Tensor] = None,
+) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    float,
+    float,
+]:
+    """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA.
+
+    Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling
+    (global FP32 + per-block FP8), and swizzles scale factors
+    for the SM100 trtllm-gen MHA kernel layout.
+
+    Args:
+        k_cache: Key cache tensor.
--
+@flashinfer_api
+def scaled_fp4_grouped_quantize(
+    a,
+    mask,
+    a_global_sf,
+):
+    """
+    Quantize batched input tensor to NVFP4 format with a mask.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        mask (torch.Tensor): Mask tensor applied before quantization.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
+    a_fp4, a_sf = get_fp4_quantization_module(
+        device_arch
--
+@flashinfer_api
+def nvfp4_kv_dequantize(
+    fp4_data: torch.Tensor,
+    block_scales: torch.Tensor,
+    global_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """GPU dequantization of NVFP4 KV cache data with linear block scale layout.
+
+    Requires SM80+.
+
+    Args:
+        fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+        block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]``
+            with dtype uint8.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as fp4_data.
+        output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype.
--
+@flashinfer_api
+def nvfp4_kv_quantize(
+    input: torch.Tensor,
+    global_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """GPU quantization to NVFP4 KV cache format with linear block scale layout.
+
+    Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16.
+            K must be divisible by 16.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as input.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]:
+            - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+            - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8.
+    """
+    M, K = input.shape
--
+@flashinfer_api
+def mxfp8_quantize(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: Optional[bool] = None,
+    backend: Literal["cuda", "cute-dsl"] = "cuda",
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to MxFP8 format.
+
+    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        alignment (int, optional): Scale factor vector size (sfVecSize). Defaults to 32.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0). Defaults to None.
+        backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are:
--
+@flashinfer_api
+def mxfp8_dequantize_host(
+    input: torch.Tensor,
+    scale_tensor: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> torch.Tensor:
+    """Dequantize input tensor from MxFP8 format.
+
+    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
+    back to float values using the associated scale factors.
+
+    Args:
+        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
+        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors.
+            If provided, it overrides is_sf_swizzled_layout. Defaults to None.
+            Available options: SfLayout.layout_128x4 and SfLayout.layout_linear.
+
+    Returns:
--
+@flashinfer_api
+def mxfp4_quantize_cute_dsl(
+    input: torch.Tensor,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format using CuTe-DSL kernel.
+
+    This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
+    - Global scale computed as (448 * 6) / max(|input|)
+    - UE8M0 scale factors
+    - E2M1 output format (4-bit, 2 values per byte)
+    - Swizzled (128x4) scale factor layout
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0).
--
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+    This is a GPU implementation with dual-path optimization:
+    - LINEAR layout: SF-block based iteration (fast)
+    - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+        alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)
```
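The E2M1 ("FP4") format that the `nvfp4_*` and `mxfp4_*` APIs above pack can be illustrated with a small pure-Python sketch. This is not FlashInfer code, only an illustration of nearest-value quantization onto the E2M1 grid (sign bit, 2 exponent bits, 1 mantissa bit; max magnitude 6; two 4-bit codes per byte, matching the "2 values per byte" noted in the docstrings); all names here are hypothetical.

```python
# Illustrative sketch only (not FlashInfer code): nearest-value quantization
# onto the E2M1 ("FP4") grid, two 4-bit codes packed per byte.

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # positive E2M1 grid

def quantize_block(block, scale):
    """Map floats to signed E2M1 codes; block length must be even."""
    codes = []
    for x in block:
        v = x / scale
        sign = 8 if v < 0 else 0
        # index of the nearest grid value to |v|
        mag = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - abs(v)))
        codes.append(sign | mag)
    # pack two 4-bit codes per byte, low nibble first
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))

def dequantize_block(packed, scale):
    out = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):
            val = E2M1_VALUES[code & 0x7] * scale
            out.append(-val if code & 0x8 else val)
    return out

packed = quantize_block([0.9, -2.8, 5.5, 0.1], scale=1.0)
print(dequantize_block(packed, 1.0))  # [1.0, -3.0, 6.0, 0.0]
```

A [M, K] tensor quantized this way shrinks to [M, K/2] bytes, which is why the packed outputs above are documented as shape `[M, K/2]` with dtype uint8.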


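Similarly, the UE8M0 scale factors mentioned in the `mxfp4_quantize_cute_dsl` docstring store only an unsigned 8-bit exponent (bias 127), so each block's scale is a power of two. A hedged sketch of how such a per-block scale might be derived; the actual kernels may round or clamp differently, and `ue8m0_scale` is a hypothetical name.

```python
import math

# Illustrative sketch only (not FlashInfer code): deriving a UE8M0 per-block
# scale. UE8M0 holds just a biased 8-bit exponent, so every scale is 2**exp.

FP4_MAX = 6.0  # largest E2M1 magnitude

def ue8m0_scale(block, elem_max=FP4_MAX):
    """Return (biased exponent byte, decoded scale) for one block."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 127, 1.0  # all-zero block: identity scale, by convention here
    # smallest power-of-two scale such that amax / scale <= elem_max
    exp = math.ceil(math.log2(amax / elem_max))
    exp = max(-127, min(127, exp))
    return exp + 127, 2.0 ** exp

print(ue8m0_scale([24.0, -1.0, 3.0]))  # (129, 4.0): 24 / 4 = 6 fits in E2M1
```

With a 32-element scale-factor vector (the `sf_vec_size=32` default in the MXFP4 paths above), one such byte is emitted per 32 input elements.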
## Summary by CodeRabbit

* **Enhancements**
  * Normalization now accepts scale as either a float or tensor; passing a float emits a deprecation warning and is auto-converted for compatibility.
  * Attention/decoding API: cache-scale parameters are now optional keyword-only arguments with sensible defaults, simplifying common call patterns.
* **Tests**
  * Tests updated to match the adjusted attention/decoding call signature.
* **Chores**
  * Release version bumped to 0.6.7.