
perf: Performance tune cute dsl RMSNorm variants #2777

Merged
bkryu merged 8 commits into flashinfer-ai:main from bkryu:cute-dsl-rmsn-perf
Mar 17, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Mar 13, 2026

📌 Description

Rewrites all CuTe-DSL RMSNorm kernel variants (rmsnorm, gemma_rmsnorm, fused_add_rmsnorm, gemma_fused_add_rmsnorm, rmsnorm_quant, fused_add_rmsnorm_quant, qk_rmsnorm, gemma_qk_rmsnorm) for higher performance.
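For reference, the core computation these variants share is RMS normalization with an optional weight bias (the gemma_* variants apply w + 1, and the *_quant variants additionally convert the result to FP8 on store). A plain-Python sketch of the math, with illustrative names not taken from the codebase:

```python
import math

def rmsnorm_ref(row, weight, eps=1e-6, weight_bias=0.0):
    # y[i] = x[i] / sqrt(mean(x^2) + eps) * (weight[i] + weight_bias)
    # gemma_* variants use weight_bias = 1.0.
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v / rms * (w + weight_bias) for v, w in zip(row, weight)]

y = rmsnorm_ref([3.0, 4.0], [1.0, 1.0], eps=0.0)
assert abs(y[0] - 3.0 / math.sqrt(12.5)) < 1e-12
```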

Key changes:

  • Multi-row blocks with async global-to-shared copy (cp.async): Each thread block processes multiple rows, improving wave utilization and hiding memory latency. Falls back to synchronous copies when alignment or shared memory constraints prevent async usage.
  • Cluster reduction on SM90+: For large hidden sizes (H > max single-CTA capacity), the workload is split across a CTA cluster that reduces partial sums via shared memory, avoiding the need for a single CTA to handle the full row.
  • Vectorized FP8 convert+store via the PTX intrinsic cvt.rn.satfinite.e4m3x2.f32, dramatically improving quantization kernel throughput.
  • Occupancy-aware shared memory management.
  • Non-contiguous tensor support without performance loss: Uses dual-path compilation — a compact kernel for contiguous inputs (optimal codegen) and a strided kernel for non-contiguous inputs (symbolic row strides). Runtime dispatch via is_contiguous() ensures zero overhead for the common contiguous case.
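The dual-path dispatch in the last bullet can be sketched in plain Python; the names below are hypothetical stand-ins, with lru_cache modeling the compiled-kernel cache:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def _get_compiled_kernel(dtype_str, H, contiguous):
    # Hypothetical stand-in for kernel compilation: a compact kernel is
    # compiled for contiguous inputs, a strided one (symbolic row strides)
    # otherwise; lru_cache caches one compiled kernel per configuration.
    variant = "compact" if contiguous else "strided"
    return f"rmsnorm_{variant}_{dtype_str}_H{H}"

def rmsnorm_dispatch(input_is_contiguous, out_is_contiguous,
                     dtype_str="bf16", H=4096):
    # Runtime dispatch on is_contiguous(): the common contiguous case
    # always selects the compact kernel with optimal codegen.
    contiguous = input_is_contiguous and out_is_contiguous
    return _get_compiled_kernel(dtype_str, H, contiguous)

assert rmsnorm_dispatch(True, True) == "rmsnorm_compact_bf16_H4096"
assert rmsnorm_dispatch(True, False) == "rmsnorm_strided_bf16_H4096"
```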
B200 performance comparison data (peak 8 TB/s)

RMSNorm

Before:
before_rmsnorm_bfloat16_NVIDIA_B200
After:
after_heatmap_rmsnorm_bfloat16_NVIDIA_B200

QK RMSNorm

Before:
before_qk_rmsnorm_bfloat16_NVIDIA_B200
After:
after_qk_rmsnorm_bfloat16_NVIDIA_B200

Add + RMSNorm + FP8 Quantize

Before:
before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200
After:
after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200

H200 performance comparison data (peak 4.8 TB/s)

RMSNorm

Before:
before_rmsnorm_bfloat16_NVIDIA_H200

After:
after_rmsnorm_bfloat16_NVIDIA_H200

RMSNorm + FP8 Quantize

Before:
before_rmsnorm_quant_bfloat16_NVIDIA_H200
After:
after_rmsnorm_quant_bfloat16_NVIDIA_H200

Add + RMSNorm + FP8 Quantize

Before:
before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200
After:
after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200

🔍 Related Issues

#2396

#2771

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • SM-version aware kernels and cluster-based tiling for multi-CTA execution
    • Contiguity-aware selection for compact vs. strided tensor paths
    • Hardware-accelerated FP8/E4M3 conversion and packed storage routines
    • New exposed utilities for device SM queries and cluster-backed reductions
  • Improvements

    • Async copy paths, expanded shared-memory and cluster-reduction support
    • Per-cluster memory/tiling estimation and improved multi-cluster reduction handling
    • Public APIs now accept an optional SM-version hint and infer/preserve contiguity

@coderabbitai
Contributor

coderabbitai bot commented Mar 13, 2026

📝 Walkthrough

Walkthrough

Adds SM-version-aware, cluster-based tiling and occupancy decisions, async copy (cp.async) paths, multi-CTA reduction and remote shared-memory utilities, and FP8 storage/conversion helpers across RMSNorm and fused Add+RMSNorm kernels and supporting utilities.

Changes

  • Fused Add + RMSNorm kernels (flashinfer/norm/kernels/fused_add_rmsnorm.py): Adds sm_version propagation, cluster_n computation and per-cluster tiling (H_per_cta), cluster-aware TV/layout generation, async-copy (cp.async) paths and shared buffers, a two-pass load+compute then normalize/store flow, cluster reductions/synchronization, contiguity-aware compilation paths, and FP8 vectorized store helpers. Kernel constructors accept an optional sm_version.
  • RMSNorm kernels (flashinfer/norm/kernels/rmsnorm.py): Introduces sm_version, cluster-aware tiling and per-cluster reductions (mbarrier/cluster sync), updated tiler/TV layout helpers, async-copy decisioning, 3D→2D tiling, cluster-scaled grid/launch logic, contiguity-aware compiled kernel selection, and FP8 quant paths. Constructors and compiled-kernel getters accept sm_version and contiguous.
  • CUDA / utility primitives (flashinfer/norm/utils.py): Adds hardware-assisted FP8/E4M3 conversion and packed-store intrinsics (cvt_and_store_*_e4m3_hw), a cached get_sm_version, cluster/remote-shared helpers (set_block_rank, store_shared_remote, elem_pointer), multi-row block/cluster reduction helpers (block_reduce_multirow, cluster_reduce_multirow, row_reduce_sum_multirow), and exposes these in __all__.

Sequence Diagram(s)

sequenceDiagram
participant Host
participant Compiler
participant Kernel as DeviceKernel
participant CTA as CTA[n]
participant MBar as mbarrier

Host->>Compiler: request kernel (contiguous?, sm_version)
Compiler-->>Host: compiled kernel + config (cluster_n, tv_layout)
Host->>Kernel: launch with grid, cluster_n, kernel args
Kernel->>CTA: CTAs execute per-cluster tiling and set block_rank
CTA->>CTA: load tiles (global → shared), optional cp.async
CTA->>MBar: participate in cluster reduction / write remote shared
MBar->>CTA: release after cluster reduction
CTA->>CTA: normalize, FP8 convert & store outputs
CTA->>Host: results in global memory

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes


Suggested labels

op: norm, cute-dsl

Suggested reviewers

  • yzh119
  • kahyunnam
  • jimmyzho
  • nv-yunzheq
  • cyx-6

Poem

🐰
I hop through clustered tiles at dawn,
SM whispers help me find my prong.
Async crumbs and shared-memory tune,
Reductions hum beneath the moon.
FP8 sparkles — a rabbit's code-song.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 56.86%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): the title accurately describes the main change: performance tuning of CuTe DSL RMSNorm variants across multiple kernel implementations.
  • Description check (✅ Passed): the PR description is comprehensive and follows the template structure with all required sections completed.



@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of various CuTe-DSL RMSNorm kernel variants, including rmsnorm, gemma_rmsnorm, fused_add_rmsnorm, gemma_fused_add_rmsnorm, rmsnorm_quant, fused_add_rmsnorm_quant, qk_rmsnorm, and gemma_qk_rmsnorm. The optimizations primarily focus on leveraging advanced GPU architectural features like CTA clusters and asynchronous memory operations on SM90+ devices, alongside improved FP8 quantization techniques and flexible tensor memory handling. These changes aim to provide substantial speedups, particularly for larger hidden dimensions and on newer hardware.

Highlights

  • Multi-row blocks with async global-to-shared copy (cp.async): Each thread block now processes multiple rows, enhancing wave utilization and memory latency hiding. The system falls back to synchronous copies when alignment or shared memory constraints prevent async usage.
  • Cluster reduction on SM90+: For large hidden sizes (H > max single-CTA capacity), the workload is now split across a CTA cluster. This cluster reduces partial sums via shared memory, eliminating the need for a single CTA to handle the entire row, leading to better scalability.
  • Vectorized FP8 convert+store PTX intrinsics: The implementation now uses cvt.rn.satfinite.e4m3x2.f32 and similar intrinsics, dramatically improving the throughput of quantization kernels by processing multiple FP32 values into FP8 simultaneously.
  • Occupancy-aware shared memory management: Shared memory allocation and usage patterns have been optimized to improve GPU occupancy, especially for fused-add kernels which require two shared memory tiles.
  • Non-contiguous tensor support without performance loss: The kernels now employ dual-path compilation: an optimal compact kernel for contiguous inputs and a strided kernel for non-contiguous inputs. Runtime dispatch ensures zero overhead for the common contiguous case, while still supporting flexible memory layouts.
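As a rough model of what the satfinite qualifier in cvt.rn.satfinite.e4m3x2.f32 buys: out-of-range finite inputs clamp to the largest finite e4m3 magnitude (±448) instead of overflowing. The sketch below covers only that clamping; the actual rounding to the 8-bit format and the 2-wide byte packing happen in hardware, and the function name is illustrative:

```python
E4M3_MAX = 448.0  # largest finite magnitude representable in e4m3

def satfinite_e4m3_pair(a, b):
    # Models only the saturation aspect of cvt.rn.satfinite.e4m3x2.f32
    # for finite inputs: values beyond the e4m3 range clamp to the max
    # finite value; two f32 inputs correspond to one packed output pair.
    clamp = lambda x: max(-E4M3_MAX, min(E4M3_MAX, x))
    return (clamp(a), clamp(b))

assert satfinite_e4m3_pair(1e6, -1e6) == (448.0, -448.0)
assert satfinite_e4m3_pair(1.5, -2.0) == (1.5, -2.0)
```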


Changelog
  • flashinfer/norm/kernels/fused_add_rmsnorm.py
    • Updated imports to include math and RMSNormKernel, and new vectorized FP8 conversion utilities and cluster reduction functions.
    • Refactored FusedAddRMSNormKernel and FusedAddRMSNormQuantKernel initializers to incorporate sm_version, cluster_n calculation, and dynamic use_async_copy determination.
    • Modified shared memory size calculation (_smem_size_in_bytes) to account for async copy buffers, reduction buffers, and mbarriers for cluster operations.
    • Adjusted kernel launch parameters (__call__ method) to support multi-row blocks and CTA clusters, including grid and cluster configurations.
    • Rewrote the core kernel logic to implement async global-to-shared memory copies, cluster reduction using row_reduce_sum_multirow, and SM90+ cluster synchronization primitives (mbarrier_init, cluster_arrive_relaxed, cluster_wait).
    • Enhanced FP8 quantization in FusedAddRMSNormQuantKernel.kernel by utilizing new vectorized PTX intrinsics for more efficient conversion and storage.
    • Updated compiled kernel retrieval functions (_get_compiled_fused_add_rmsnorm_kernel, _get_compiled_fused_add_rmsnorm_quant_kernel) to accept sm_version and contiguous flags, enabling dual-path compilation for optimal and strided tensor layouts.
    • Modified fused_add_rmsnorm_cute and fused_add_rmsnorm_quant_cute entry points to detect tensor contiguity and pass sm_version and contiguous information to the underlying kernels.
  • flashinfer/norm/kernels/rmsnorm.py
    • Updated imports to include math, vectorized FP8 conversion utilities, get_sm_version, and multi-row reduction functions, while removing older reduction and utility imports.
    • Refactored RMSNormKernel and RMSNormQuantKernel initializers to support sm_version, cluster_n determination, and use_async_copy logic, along with new static helper methods for thread and layout computation.
    • Modified shared memory size calculation (_smem_size_in_bytes) to dynamically adjust for async copy buffers, reduction buffers, and mbarriers based on cluster configuration.
    • Adjusted kernel launch parameters (__call__ method) to enable multi-row block processing and CTA cluster execution for improved scalability.
    • Rewrote the core kernel logic for RMSNormKernel and RMSNormQuantKernel to incorporate async global-to-shared memory copies, row_reduce_sum_multirow for cluster-aware reductions, and mbarrier synchronization.
    • Significantly refactored QKRMSNormKernel to remove num_warps as a parameter, update thread and vectorization calculations, and adapt its kernel to use multi-row blocks and async copies, replacing the previous warp-only reduction approach.
    • Enhanced FP8 quantization in RMSNormQuantKernel.kernel by integrating vectorized PTX intrinsics for more efficient FP8 conversion and storage.
    • Updated compiled kernel retrieval functions (_get_compiled_rmsnorm_kernel, _get_compiled_qk_rmsnorm_kernel, _get_compiled_rmsnorm_quant_kernel) to accept sm_version and contiguous flags, enabling specialized compilation paths.
    • Modified rmsnorm_cute, qk_rmsnorm_cute, and rmsnorm_quant_cute entry points to detect tensor contiguity (where applicable) and pass relevant device and layout information to the compiled kernels.
  • flashinfer/norm/utils.py
    • Added functools import for caching mechanisms.
    • Introduced new dsl_user_op functions: cvt_and_store_8xf32_to_e4m3_hw, cvt_and_store_4xf32_to_e4m3_hw, and cvt_and_store_2xf32_to_e4m3_hw for vectorized hardware-accelerated FP8 conversion and storage on SM89+ GPUs.
    • Implemented get_sm_version utility function, cached for efficiency, to retrieve the CUDA device's Streaming Multiprocessor version.
    • Added dsl_user_op functions for SM90+ cluster operations: set_block_rank for mapping shared memory pointers across CTAs, store_shared_remote for asynchronous writes to remote shared memory, and elem_pointer for obtaining element pointers.
    • Introduced new reduction utilities: block_reduce_multirow for block-level reductions with 2D buffers, cluster_reduce_multirow for cross-CTA cluster reductions using mbarriers, and row_reduce_sum_multirow to provide a unified interface for row-wise sum reductions with optional cluster support.
    • Updated the __all__ export list to include all newly added functions and utilities.

@bkryu bkryu self-assigned this Mar 13, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant performance optimizations to the CuTe DSL RMSNorm kernels. The changes include multi-row processing per thread block, SM90+ cluster reductions, vectorized FP8 conversions, and a dual-path compilation strategy for contiguous and non-contiguous tensors. The refactoring improves code reuse by centralizing kernel configuration logic in RMSNormKernel.

My review focuses on the correctness and efficiency of the new implementations. I've identified a recurring performance issue where intermediate values are unnecessarily recomputed in the async copy path of several kernels. Addressing this should further improve the performance of these already impressive optimizations.

@bkryu
Collaborator Author

bkryu commented Mar 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !414 has been created, and the CI pipeline #46018154 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py`:
- Around line 752-839: The stores use the compact linear index expression
Int32(actual_row * H + abs_col) which breaks when outputs are non-contiguous;
replace all uses of that expression in every FP8 store path (all calls to
get_ptr_as_int64(mY, Int32(...)) inside cvt_and_store_* and
cvt_and_store_*_hw/_sw) with a layout-aware coordinate→index lookup (e.g., call
the tensor/memory layout helper such as
mY_layout.coordinate_to_index(actual_row, abs_col) or the existing
coordinate_to_index(mY, row, col) helper) so the pointer is computed using the
actual mY layout when contiguous=False; apply this change for the vectorized
fast-paths (cvt_and_store_8xf32_to_e4m3_hw, cvt_and_store_4xf32_to_e4m3_hw,
cvt_and_store_2xf32_to_e4m3_hw) and the per-element slow-paths
(cvt_and_store_f32_to_e4m3_hw / cvt_and_store_f32_to_e4m3_sw).

In `@flashinfer/norm/kernels/rmsnorm.py`:
- Around line 954-1041: The FP8 store paths compute linear indices using
actual_row * H + abs_col which assumes a contiguous row-major layout; instead
use the matrix row stride (e.g., row_stride or the existing leading-dimension
variable used for mY) when forming the pointer. Update every occurrence of
get_ptr_as_int64(mY, Int32(actual_row * H + abs_col)) (and the per-element
Int32(actual_row * H + abs_col_e)) to compute Int32(actual_row * row_stride +
abs_col) (or Int32(actual_row * row_stride + abs_col_e)) so stores
(cvt_and_store_*_to_e4m3_hw and cvt_and_store_f32_to_e4m3_sw) use the proper
layout-aware address; apply this change in all branches (vec_size 8/4/2 and the
scalar loop) and keep the same clamp logic.
- Around line 101-103: The kernel tuning uses torch.cuda.current_device() while
sm_version comes from input.device, causing cross-GPU mismatches; fix by
switching device context to the input tensor's device when querying properties
and compiling kernels—use input.device instead of torch.cuda.current_device()
for calls like torch.cuda.get_device_properties and for evaluating
tile_bytes/use_async_copy, and wrap kernel compilation blocks with with
torch.cuda.device(input.device): (same pattern as flashinfer/decode.py) so
sm_version, shared_memory_per_block_optin checks, and any compile steps (the
places that set use_async_copy, evaluate tile_bytes, and invoke compilation in
rmsnorm.py and corresponding fused_add_rmsnorm.py) all run on the same CUDA
device.
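The fix both inline comments ask for boils down to using the tensor's real row stride when forming the linear index rather than assuming stride == H. A minimal model of the two addressing schemes (the helper name is hypothetical):

```python
def linear_index(row, col, H, row_stride=None):
    # Compact path: contiguous row-major layout, so the row stride
    # equals the hidden size H. Strided path: use the tensor's actual
    # row stride (e.g. a column slice of a wider tensor has stride > H).
    stride = H if row_stride is None else row_stride
    return row * stride + col

# contiguous: stride equals H
assert linear_index(2, 5, H=8) == 21
# non-contiguous: same logical coordinate, different memory offset
assert linear_index(2, 5, H=8, row_stride=16) == 37
```

Using `row * H + col` on the strided tensor above would address element 21 instead of 37, which is exactly the corruption the review flags for the non-contiguous FP8 store paths.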

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7c8ca37f-7df9-421b-8361-dbab503a9ffe

📥 Commits

Reviewing files that changed from the base of the PR and between e3aa638 and fd5035f.

📒 Files selected for processing (3)
  • flashinfer/norm/kernels/fused_add_rmsnorm.py
  • flashinfer/norm/kernels/rmsnorm.py
  • flashinfer/norm/utils.py

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
flashinfer/norm/kernels/rmsnorm.py (1)

106-123: Minor inconsistency: _compute_cluster_n fallback differs from FusedAddRMSNormKernel.

This method returns a hardcoded 16 if no cluster_n satisfies the SMEM constraint (line 123), while FusedAddRMSNormKernel._compute_cluster_n tracks a best_fit and returns that instead (lines 127-139 in fused_add_rmsnorm.py). The best_fit approach is slightly more robust since it ensures the returned cluster_n at least divides H evenly.

This is unlikely to cause issues in practice since H is typically a power of 2, but consider aligning with the fused kernel's pattern for consistency.

🔧 Suggested fix to align with FusedAddRMSNormKernel pattern
     @staticmethod
     def _compute_cluster_n(H: int, dtype: cutlass.Numeric, sm_version: int) -> int:
         """Compute optimal cluster size based on H and device shared memory."""
         if sm_version < 90:
             return 1

         props = torch.cuda.get_device_properties(torch.cuda.current_device())
         max_smem_bytes = props.shared_memory_per_block_optin
         elem_size = dtype.width // 8

+        best_fit = 1
         for cluster_n in [1, 2, 4, 8, 16]:
             if H % cluster_n != 0:
                 continue
+            best_fit = cluster_n  # largest divisor of H seen so far
             smem_needed = RMSNormKernel._estimate_smem_bytes(H, cluster_n, elem_size)
             if smem_needed <= max_smem_bytes:
                 return cluster_n

-        return 16
+        return best_fit
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/kernels/rmsnorm.py` around lines 106 - 123,
RMSNormKernel._compute_cluster_n currently returns a hardcoded 16 when no
cluster_n fits SMEM; change it to mirror FusedAddRMSNormKernel's behavior by
tracking a best_fit (initialize to 1 or the largest divisor candidate that
divides H) while iterating the candidate cluster_n list [1,2,4,8,16], update
best_fit whenever a candidate divides H and has smem_needed <= max_smem_bytes
(or when it divides H even if smem_needed > max to prefer larger divisors), and
after the loop return best_fit instead of 16; keep using
RMSNormKernel._estimate_smem_bytes, dtype.width//8 for elem_size, and the same
device props lookup to locate the code.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6b612562-fbf7-409f-9de3-9ef69a067c6f

📥 Commits

Reviewing files that changed from the base of the PR and between fd5035f and 2b8ff35.

📒 Files selected for processing (2)
  • flashinfer/norm/kernels/fused_add_rmsnorm.py
  • flashinfer/norm/kernels/rmsnorm.py

@bkryu bkryu marked this pull request as draft March 13, 2026 01:28
@bkryu bkryu marked this pull request as ready for review March 13, 2026 05:04
@bkryu
Collaborator Author

bkryu commented Mar 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !414 has been updated with latest changes, and the CI pipeline #46031994 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/kernels/rmsnorm.py`:
- Around line 116-124: The fallback unconditionally returns 16 which can violate
the H % cluster_n == 0 invariant; update the selection logic in the routine that
iterates cluster_n (the loop using RMSNormKernel._estimate_smem_bytes) to first
collect cluster_n values that divide H, then return the first one with
smem_needed <= max_smem_bytes, and if none fit, return the largest cluster_n
from the divisors list (the largest valid divisor of H from [1,2,4,8,16])
instead of hardcoding 16 so divisibility is always preserved.
- Around line 1292-1310: Replace the silent-copy-prone reshape calls with view
for the 3D path: when len(shape) == 3, use input.view(M, H) and out.view(M, H)
instead of input.reshape(...) and out.reshape(...), so any non-contiguous
layouts raise an error instead of creating a hidden copy; keep the subsequent
is_contiguous() check on input_2d and out_2d and pass those variables to
_get_compiled_rmsnorm_kernel and kernel(...) as before (symbols: input_2d,
out_2d, input.reshape/out.reshape -> input.view/out.view, is_contiguous,
_get_compiled_rmsnorm_kernel, kernel).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 091757ba-a8f6-4ad2-8643-5950c49a2c70

📥 Commits

Reviewing files that changed from the base of the PR and between 2b8ff35 and 6377688.

📒 Files selected for processing (1)
  • flashinfer/norm/kernels/rmsnorm.py

Comment on lines 1292 to 1310

     if len(shape) == 3:
         M = shape[0] * shape[1]
-        input_2d = input.view(M, H)
-        out_2d = out.view(M, H)
+        input_2d = input.reshape(M, H)
+        out_2d = out.reshape(M, H)
     else:
         M = shape[0]
         input_2d = input
         out_2d = out

+    is_contiguous = input_2d.is_contiguous() and out_2d.is_contiguous()
     kernel = _get_compiled_rmsnorm_kernel(
-        _torch_dtype_to_str(input.dtype), H, weight_bias, enable_pdl
+        _torch_dtype_to_str(input.dtype),
+        H,
+        weight_bias,
+        enable_pdl,
+        get_sm_version(input.device),
+        contiguous=is_contiguous,
     )
     kernel(input_2d, weight, out_2d, M, eps)
Contributor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify the current flattening path and check for 3D non-contiguous guard coverage.
sed -n '1288,1312p' flashinfer/norm/kernels/rmsnorm.py

echo "----"

rg -n "reshape\\(M, H\\)|view\\(M, H\\)|is_contiguous\\(" flashinfer/norm/kernels/rmsnorm.py -C 2

echo "----"

rg -n "test_norm\\(|contiguous" tests/utils/test_norm.py -C 3

Repository: flashinfer-ai/flashinfer

Length of output: 5429


🏁 Script executed:

# Extract the test_qknorm function to see how it validates the output
sed -n '162,190p' tests/utils/test_norm.py

echo "----"

# Check what the kernel function signature and behavior expects
sed -n '1270,1320p' flashinfer/norm/kernels/rmsnorm.py

Repository: flashinfer-ai/flashinfer

Length of output: 2500


Use view() instead of reshape() for the 3D flattening path to avoid silent data loss.

Lines 1294–1295 use reshape(), which can silently create copies for certain non-contiguous 3D tensor layouts. If a copy is created, the kernel writes to a temporary, not the original out tensor. The subsequent is_contiguous check on line 1301 inspects the reshaped tensors (which are always contiguous if they exist), so it cannot detect this issue.

Replacing reshape() with view() will raise an explicit error if the reshape would require a copy, making the constraint clear and catching misuse before silent corruption occurs.

Proposed fix
     if len(shape) == 3:
         M = shape[0] * shape[1]
-        input_2d = input.reshape(M, H)
-        out_2d = out.reshape(M, H)
+        try:
+            input_2d = input.view(M, H)
+            out_2d = out.view(M, H)
+        except RuntimeError as e:
+            raise ValueError(
+                "rmsnorm_cute expects 3D tensors flattenable to (M, H) without copy; "
+                "call contiguous() before rmsnorm_cute for this layout."
+            ) from e
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/kernels/rmsnorm.py` around lines 1292 - 1310, Replace the
silent-copy-prone reshape calls with view for the 3D path: when len(shape) == 3,
use input.view(M, H) and out.view(M, H) instead of input.reshape(...) and
out.reshape(...), so any non-contiguous layouts raise an error instead of
creating a hidden copy; keep the subsequent is_contiguous() check on input_2d
and out_2d and pass those variables to _get_compiled_rmsnorm_kernel and
kernel(...) as before (symbols: input_2d, out_2d, input.reshape/out.reshape ->
input.view/out.view, is_contiguous, _get_compiled_rmsnorm_kernel, kernel).
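To make the reshape() hazard concrete, here is an analogous NumPy demonstration (torch semantics are similar: reshape may silently copy, while a strict no-copy reshape refuses and raises):

```python
import numpy as np

# A non-contiguous 2D view: the transpose of a contiguous array.
base = np.arange(6, dtype=np.float32).reshape(2, 3)
nc = base.T  # shape (3, 2), non-contiguous

# reshape() silently copies when the strides don't permit a view,
# so writes to the result never reach the original storage.
flat = nc.reshape(6)
flat[0] = 99.0
assert base[0, 0] == 0.0  # original untouched: flat was a copy

# NumPy's strict no-copy reshape (in-place shape assignment) refuses
# and raises instead of copying, surfacing the layout constraint.
try:
    nc.shape = (6,)
    raised = False
except AttributeError:
    raised = True
assert raised
```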

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46031994: 9/20 passed

@yongwww
Member

yongwww commented Mar 13, 2026

I cancelled the PR test because CI won't pass before #2781 lands; please re-trigger the test after that PR is merged.

claude bot pushed a commit that referenced this pull request Mar 13, 2026
The existing RMSNormKernel reads input from global memory twice: once in
Phase 1 (to compute sum of squares) and again in Phase 2 (to compute the
normalised output). For small hidden sizes where the input row fits in
shared memory, we can cache it in smem after the first load and re-use it
in Phase 2, reducing global memory traffic from 4*d*sizeof(T) to
3*d*sizeof(T) per row (a 25% improvement).

New kernels added to include/flashinfer/norm.cuh:
- RMSNormSmemKernel<VEC_SIZE, T>: stores input (as T, not float) into
  shared memory during Phase 1 using vectorised 128-bit stores, then
  loads it back with vectorised 128-bit reads in Phase 2.
- RMSNormQuantSmemKernel<VEC_SIZE, T, O>: same optimisation applied to
  the FP8-quantised variant.

Shared memory layout:
  [0, align16): warp reduction buffer (float, 16-byte aligned)
  [align16, ...): input cache (T, d elements)

Dispatch logic in RMSNorm(), RMSNormQuant(), and GemmaRMSNorm():
  1. Try to set max dynamic smem to smem_size_smem via
     cudaFuncSetAttribute; if it succeeds, launch the smem-caching kernel.
  2. If smem is insufficient (hidden size too large), fall back to the
     original two-pass global-memory kernel transparently.

For bfloat16, the smem-caching variant fits without extended smem for
d <= ~24000 (48 KB limit). Common model sizes (2880, 4096, 7168, 8192)
all fit comfortably.

Note: the new CuTe DSL kernels (PR #2777) already keep input in register
memory throughout and do not re-read global memory, so they are not
affected by this change. This optimisation targets the CUDA C++ fallback
path.

AI-assisted implementation.

Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
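The traffic accounting in the commit message above (input read once per phase, one weight read, one output write per row) can be checked with a tiny model; the numbers follow the commit's 4*d vs 3*d claim:

```python
def global_traffic_bytes(d, elem_size, smem_cached):
    # Two-pass kernel reads the input row twice (phase 1 sum of squares,
    # phase 2 normalize); the smem-caching variant reads it only once.
    input_reads = 1 if smem_cached else 2
    weight_reads = 1
    output_writes = 1
    return (input_reads + weight_reads + output_writes) * d * elem_size

d, sz = 4096, 2  # bfloat16, a common hidden size
assert global_traffic_bytes(d, sz, smem_cached=False) == 4 * d * sz
assert global_traffic_bytes(d, sz, smem_cached=True) == 3 * d * sz
# per the commit, the cached bf16 input fits in the default 48 KB smem
# up to d ~ 24000 (ignoring the small warp-reduction buffer)
assert 24000 * sz <= 48 * 1024
```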
@bkryu bkryu force-pushed the cute-dsl-rmsn-perf branch from 6377688 to 3850674 Compare March 14, 2026 01:58
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (2)
flashinfer/norm/kernels/rmsnorm.py (2)

1292-1310: ⚠️ Potential issue | 🟠 Major

Consider using view() instead of reshape() for 3D tensor flattening.

Lines 1294-1295 use reshape() which can silently create copies for certain non-contiguous 3D layouts. If a copy is created, the kernel writes to a temporary tensor, not the original out. The subsequent is_contiguous() check on the reshaped tensors won't detect this since reshaped copies are contiguous.

Using view() would raise an error for non-flattenable layouts, making the constraint explicit. This was flagged in a previous review but appears unaddressed.

Suggested fix
     if len(shape) == 3:
         M = shape[0] * shape[1]
-        input_2d = input.reshape(M, H)
-        out_2d = out.reshape(M, H)
+        input_2d = input.view(M, H)
+        out_2d = out.view(M, H)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/kernels/rmsnorm.py` around lines 1292 - 1310, The reshape()
calls when flattening a 3D input can create silent copies leading the kernel to
write into temporaries; replace input.reshape(M, H) and out.reshape(M, H) with
input.view(M, H) and out.view(M, H) in the branch where len(shape) == 3 so
non-flattenable (non-viewable) layouts raise an error instead of producing a
copy, keep the subsequent is_contiguous check on input_2d/out_2d, and ensure the
kernel call (kernel(input_2d, weight, out_2d, M, eps)) still receives the view
of the original tensors so writes go to the original out; reference symbols:
input.reshape, out.reshape, input.view, out.view, input_2d/is_contiguous,
_get_compiled_rmsnorm_kernel, kernel(...).
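The failure mode is easy to reproduce with NumPy, whose `reshape()` has the same return-a-copy-when-a-view-is-impossible semantics as PyTorch's (PyTorch's `view()` would instead raise):

```python
import numpy as np

# A non-contiguous 3D layout: transpose the first two axes of a (2, 3, 4)
# array. Flattening it to (6, 4) cannot be expressed as a view, so
# reshape() silently returns a copy -- writes to it never reach `out`.
out = np.arange(24).reshape(2, 3, 4).transpose(1, 0, 2)
out_2d = out.reshape(6, 4)

print(np.shares_memory(out, out_2d))  # False: a silent copy was made
out_2d[0, 0] = -1
print(out[0, 0, 0])                   # 0: the original is untouched
```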

106-123: ⚠️ Potential issue | 🟡 Minor

Fallback cluster_n=16 may still violate divisibility constraint.

The loop correctly filters candidates by H % cluster_n == 0, but the fallback on line 123 unconditionally returns 16. If no candidate that divides H fits in SMEM and H itself is not divisible by 16 (e.g., H divisible by 8 but not by 16), returning 16 would break the invariant.

Although this scenario is rare (most cluster sizes would fit for reasonable H values), consider returning the last valid candidate that passed the divisibility check:

Suggested fix
     @staticmethod
     def _compute_cluster_n(H: int, dtype: cutlass.Numeric, sm_version: int) -> int:
         """Compute optimal cluster size based on H and device shared memory."""
         if sm_version < 90:
             return 1

         props = torch.cuda.get_device_properties(torch.cuda.current_device())
         max_smem_bytes = props.shared_memory_per_block_optin
         elem_size = dtype.width // 8

+        valid_candidates = [c for c in [1, 2, 4, 8, 16] if H % c == 0]
-        for cluster_n in [1, 2, 4, 8, 16]:
-            if H % cluster_n != 0:
-                continue
+        for cluster_n in valid_candidates:
             smem_needed = RMSNormKernel._estimate_smem_bytes(H, cluster_n, elem_size)
             if smem_needed <= max_smem_bytes:
                 return cluster_n

-        return 16
+        return valid_candidates[-1]  # Largest valid divisor
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/kernels/rmsnorm.py` around lines 106 - 123, The fallback in
RMSNormKernel._compute_cluster_n unconditionally returns 16 which can violate
the H % cluster_n == 0 invariant; change the function to record the last
candidate that passed the divisibility check (e.g., last_valid = None) while
iterating the candidates [1,2,4,8,16], use RMSNormKernel._estimate_smem_bytes to
test SMEM fit, and if none of the candidates fit SMEM return a safe divisible
fallback (last_valid if set, otherwise 1) instead of always returning 16.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/norm/kernels/rmsnorm.py`:
- Around line 1292-1310: The reshape() calls when flattening a 3D input can
create silent copies leading the kernel to write into temporaries; replace
input.reshape(M, H) and out.reshape(M, H) with input.view(M, H) and out.view(M,
H) in the branch where len(shape) == 3 so non-flattenable (non-viewable) layouts
raise an error instead of producing a copy, keep the subsequent is_contiguous
check on input_2d/out_2d, and ensure the kernel call (kernel(input_2d, weight,
out_2d, M, eps)) still receives the view of the original tensors so writes go to
the original out; reference symbols: input.reshape, out.reshape, input.view,
out.view, input_2d/is_contiguous, _get_compiled_rmsnorm_kernel, kernel(...).
- Around line 106-123: The fallback in RMSNormKernel._compute_cluster_n
unconditionally returns 16 which can violate the H % cluster_n == 0 invariant;
change the function to record the last candidate that passed the divisibility
check (e.g., last_valid = None) while iterating the candidates [1,2,4,8,16], use
RMSNormKernel._estimate_smem_bytes to test SMEM fit, and if none of the
candidates fit SMEM return a safe divisible fallback (last_valid if set,
otherwise 1) instead of always returning 16.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c27c709a-8ef2-4ce6-b435-c9632fc1bf69

📥 Commits

Reviewing files that changed from the base of the PR and between 6377688 and 3850674.

📒 Files selected for processing (3)
  • flashinfer/norm/kernels/fused_add_rmsnorm.py
  • flashinfer/norm/kernels/rmsnorm.py
  • flashinfer/norm/utils.py

@bkryu
Collaborator Author

bkryu commented Mar 14, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !414 has been updated with latest changes, and the CI pipeline #46109078 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46109078: 10/20 passed

Collaborator

@nv-yunzheq nv-yunzheq left a comment


Unit test looks good

Collaborator

@kahyunnam kahyunnam left a comment


lgtm!

@bkryu bkryu merged commit f7322d9 into flashinfer-ai:main Mar 17, 2026
31 of 39 checks passed
@bkryu bkryu deleted the cute-dsl-rmsn-perf branch March 17, 2026 22:11
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Rewrites all CuTe-DSL RMSNorm kernel variants (`rmsnorm`,
`gemma_rmsnorm`, `fused_add_rmsnorm`, `gemma_fused_add_rmsnorm`,
`rmsnorm_quant`, `fused_add_rmsnorm_quant`, `qk_rmsnorm`,
`gemma_qk_rmsnorm`)

**Key changes:**
* Multi-row blocks with async global-to-shared copy (cpasync): Each
thread block processes multiple rows, improving wave utilization and
hiding memory latency. Falls back to synchronous copies when alignment
or shared memory constraints prevent async usage.
* Cluster reduction on SM90+: For large hidden sizes (H > max single-CTA
capacity), the workload is split across a CTA cluster that reduces
partial sums via shared memory, avoiding the need for a single CTA to
handle the full row.
* Vectorized FP8 convert+store PTX intrinsics
`cvt.rn.satfinite.e4m3x2.f32`, dramatically improving quantization
kernel throughput.
* Occupancy-aware shared memory management
* Non-contiguous tensor support without performance loss: Uses dual-path
compilation — a compact kernel for contiguous inputs (optimal codegen)
and a strided kernel for non-contiguous inputs (symbolic row strides).
Runtime dispatch via is_contiguous() ensures zero overhead for the
common contiguous case.



<details>
<summary>Click to see B200 performance comparison data (Peak 8
TB/s)</summary>

**RMSNorm**

Before:
<img width="1905" height="1680"
alt="before_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/15582140-f6df-4794-a4b4-2cc19d252dbb"
/>
After
<img width="1905" height="1680"
alt="after_heatmap_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/0d306806-36d2-4576-a6c2-9f4629f277f8"
/>

**QK RMSNorm**

Before:
<img width="1905" height="1680"
alt="before_qk_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/71540b32-1df7-4772-94a7-b6b8c71080ee"
/>
After:
<img width="1905" height="1680"
alt="after_qk_rmsnorm_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/04e95f62-73fe-43f4-b1a1-95eff234e379"
/>

**Add + RMSNorm + FP8 Quantize**

Before:
<img width="1905" height="1680"
alt="before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/7bdda617-2d20-4a05-b7fd-2e9e489acba7"
/>
After:
<img width="1905" height="1680"
alt="after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_B200"
src="https://github.com/user-attachments/assets/663fb2a5-45cf-4fab-a74b-dc338d7d8bd0"
/>

</details>

<details>
<summary>Click to see H200 performance comparison data (Peak 4.8
TB/s)</summary>

**RMSNorm**

Before:
<img width="1905" height="1680"
alt="before_rmsnorm_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/42f63c06-8f6f-4ada-b6fd-e19de4ee32cc"
/>

After:
<img width="1905" height="1680" alt="after_rmsnorm_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/ae30fc58-159e-43b6-b108-850bf1711cad"
/>

**RMSNorm + FP8 Quantize**

Before:
<img width="1905" height="1680"
alt="before_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/52469123-6a5f-459a-ae0b-586a11370ac9"
/>
After:
<img width="1905" height="1680"
alt="after_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/4a229d4a-10ea-4d89-985f-c0378c6554d4"
/>


**Add + RMSNorm + FP8 Quantize**

Before:
<img width="1905" height="1680"
alt="before_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/78ac50aa-ae6a-4ea6-a585-0b326279e96b"
/>
After:
<img width="1905" height="1680"
alt="after_fused_add_rmsnorm_quant_bfloat16_NVIDIA_H200"
src="https://github.com/user-attachments/assets/8268ffb8-0ee0-49b7-9353-8d0151002329"
/>

</details>

## 🔍 Related Issues

<!-- Link any related issues here -->

flashinfer-ai#2396 

flashinfer-ai#2771 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* SM-version aware kernels and cluster-based tiling for multi-CTA
execution
  * Contiguity-aware selection for compact vs. strided tensor paths
  * Hardware-accelerated FP8/E4M3 conversion and packed storage routines
* New exposed utilities for device SM queries and cluster-backed
reductions

* **Improvements**
* Async copy paths, expanded shared-memory and cluster-reduction support
* Per-cluster memory/tiling estimation and improved multi-cluster
reduction handling
* Public APIs now accept an optional SM-version hint and infer/preserve
contiguity
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
