perf: Performance tune cute dsl RMSNorm variants #2777

bkryu merged 8 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Adds SM-version-aware, cluster-based tiling and occupancy decisions, async copy (cp.async) paths, multi-CTA reduction and remote shared-memory utilities, and FP8 storage/conversion helpers across RMSNorm and fused Add+RMSNorm kernels and supporting utilities.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host
    participant Compiler
    participant Kernel as DeviceKernel
    participant CTA as CTA[n]
    participant MBar as mbarrier
    Host->>Compiler: request kernel (contiguous?, sm_version)
    Compiler-->>Host: compiled kernel + config (cluster_n, tv_layout)
    Host->>Kernel: launch with grid, cluster_n, kernel args
    Kernel->>CTA: CTAs execute per-cluster tiling and set block_rank
    CTA->>CTA: load tiles (global -> shared), optional cp.async
    CTA->>MBar: participate in cluster reduction / write remote shared
    MBar->>CTA: release after cluster reduction
    CTA->>CTA: normalize, FP8 convert & store outputs
    CTA->>Host: results in global memory
```
Code Review
This pull request introduces significant performance optimizations to the CuTe DSL RMSNorm kernels. The changes include multi-row processing per thread block, SM90+ cluster reductions, vectorized FP8 conversions, and a dual-path compilation strategy for contiguous and non-contiguous tensors. The refactoring improves code reuse by centralizing kernel configuration logic in RMSNormKernel.
My review focuses on the correctness and efficiency of the new implementations. I've identified a recurring performance issue where intermediate values are unnecessarily recomputed in the async copy path of several kernels. Addressing this should further improve the performance of these already impressive optimizations.
/bot run
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/kernels/fused_add_rmsnorm.py`:
- Around line 752-839: The stores use the compact linear index expression
Int32(actual_row * H + abs_col) which breaks when outputs are non-contiguous;
replace all uses of that expression in every FP8 store path (all calls to
get_ptr_as_int64(mY, Int32(...)) inside cvt_and_store_* and
cvt_and_store_*_hw/_sw) with a layout-aware coordinate→index lookup (e.g., call
the tensor/memory layout helper such as
mY_layout.coordinate_to_index(actual_row, abs_col) or the existing
coordinate_to_index(mY, row, col) helper) so the pointer is computed using the
actual mY layout when contiguous=False; apply this change for the vectorized
fast-paths (cvt_and_store_8xf32_to_e4m3_hw, cvt_and_store_4xf32_to_e4m3_hw,
cvt_and_store_2xf32_to_e4m3_hw) and the per-element slow-paths
(cvt_and_store_f32_to_e4m3_hw / cvt_and_store_f32_to_e4m3_sw).
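The addressing bug called out above is easy to reproduce outside the kernel: for a tensor whose rows are not densely packed, `row * H + col` lands on the wrong element, while a stride-aware `row * row_stride + col` does not. A minimal NumPy sketch (the array names are illustrative, not from the PR):

```python
import numpy as np

# A 4x8 parent buffer; the first H=4 columns form a non-contiguous 2-D view,
# mimicking an output tensor whose row stride exceeds its width.
parent = np.arange(32, dtype=np.float32).reshape(4, 8)
H = 4
view = parent[:, :H]

flat = parent.ravel()                           # the underlying storage
row, col = 2, 3
row_stride = view.strides[0] // view.itemsize   # 8 elements, not H

wrong = flat[row * H + col]                     # compact index: wrong element
right = flat[row * row_stride + col]            # stride-aware index: correct

assert right == view[row, col] and wrong != view[row, col]
```

A layout-aware helper like the suggested `coordinate_to_index` would derive the `row * row_stride + col` form from the tensor's actual strides instead of assuming dense rows.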
In `@flashinfer/norm/kernels/rmsnorm.py`:
- Around line 954-1041: The FP8 store paths compute linear indices using
actual_row * H + abs_col which assumes a contiguous row-major layout; instead
use the matrix row stride (e.g., row_stride or the existing leading-dimension
variable used for mY) when forming the pointer. Update every occurrence of
get_ptr_as_int64(mY, Int32(actual_row * H + abs_col)) (and the per-element
Int32(actual_row * H + abs_col_e)) to compute Int32(actual_row * row_stride +
abs_col) (or Int32(actual_row * row_stride + abs_col_e)) so stores
(cvt_and_store_*_to_e4m3_hw and cvt_and_store_f32_to_e4m3_sw) use the proper
layout-aware address; apply this change in all branches (vec_size 8/4/2 and the
scalar loop) and keep the same clamp logic.
- Around line 101-103: The kernel tuning uses torch.cuda.current_device() while
sm_version comes from input.device, causing cross-GPU mismatches; fix by
switching device context to the input tensor's device when querying properties
and compiling kernels—use input.device instead of torch.cuda.current_device()
for calls like torch.cuda.get_device_properties and for evaluating
tile_bytes/use_async_copy, and wrap kernel compilation blocks with with
torch.cuda.device(input.device): (same pattern as flashinfer/decode.py) so
sm_version, shared_memory_per_block_optin checks, and any compile steps (the
places that set use_async_copy, evaluate tile_bytes, and invoke compilation in
rmsnorm.py and corresponding fused_add_rmsnorm.py) all run on the same CUDA
device.
🧹 Nitpick comments (1)
flashinfer/norm/kernels/rmsnorm.py (1)
**106-123: Minor inconsistency: `_compute_cluster_n` fallback differs from `FusedAddRMSNormKernel`.**

This method returns a hardcoded 16 if no cluster_n satisfies the SMEM constraint (line 123), while `FusedAddRMSNormKernel._compute_cluster_n` tracks a `best_fit` and returns that instead (lines 127-139 in fused_add_rmsnorm.py). The `best_fit` approach is slightly more robust since it ensures the returned cluster_n at least divides H evenly. This is unlikely to cause issues in practice since H is typically a power of 2, but consider aligning with the fused kernel's pattern for consistency.
🔧 Suggested fix to align with FusedAddRMSNormKernel pattern:

```diff
 @staticmethod
 def _compute_cluster_n(H: int, dtype: cutlass.Numeric, sm_version: int) -> int:
     """Compute optimal cluster size based on H and device shared memory."""
     if sm_version < 90:
         return 1
     props = torch.cuda.get_device_properties(torch.cuda.current_device())
     max_smem_bytes = props.shared_memory_per_block_optin
     elem_size = dtype.width // 8
+    best_fit = 1
     for cluster_n in [1, 2, 4, 8, 16]:
         if H % cluster_n != 0:
             continue
+        best_fit = cluster_n
         smem_needed = RMSNormKernel._estimate_smem_bytes(H, cluster_n, elem_size)
         if smem_needed <= max_smem_bytes:
             return cluster_n
-    return 16
+    return best_fit
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/rmsnorm.py` around lines 106-123, RMSNormKernel._compute_cluster_n currently returns a hardcoded 16 when no cluster_n fits SMEM; change it to mirror FusedAddRMSNormKernel's behavior by tracking a best_fit while iterating the candidate cluster_n list [1, 2, 4, 8, 16], updating best_fit whenever a candidate divides H, and returning best_fit instead of 16 after the loop; keep using RMSNormKernel._estimate_smem_bytes, dtype.width // 8 for elem_size, and the same device props lookup.
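The selection policy suggested here can be modeled in plain Python: prefer the smallest cluster_n that divides H and whose per-CTA slice fits in shared memory, and fall back to the largest valid divisor instead of a hardcoded 16. The byte estimate below is a stand-in for `_estimate_smem_bytes`, not the real formula:

```python
def compute_cluster_n(H: int, elem_size: int, max_smem_bytes: int) -> int:
    """Pick the smallest cluster_n whose per-CTA slice of the row fits in
    shared memory; fall back to the largest divisor of H among the
    candidates so the H % cluster_n == 0 invariant always holds."""
    best_fit = 1
    for cluster_n in [1, 2, 4, 8, 16]:
        if H % cluster_n != 0:
            continue
        best_fit = cluster_n                        # largest valid divisor so far
        smem_needed = (H // cluster_n) * elem_size  # stand-in byte estimate
        if smem_needed <= max_smem_bytes:
            return cluster_n
    return best_fit

# A bf16 row of H=4096 fits one CTA when 8 KB of SMEM is available...
assert compute_cluster_n(4096, 2, 8192) == 1
# ...but must be split across 8 CTAs when only 1 KB is available.
assert compute_cluster_n(4096, 2, 1024) == 8
# If nothing fits, the fallback is still a divisor of H (here 4, not 16).
assert compute_cluster_n(12, 2, 1) == 4
```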
/bot run
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/kernels/rmsnorm.py`:
- Around line 116-124: The fallback unconditionally returns 16 which can violate
the H % cluster_n == 0 invariant; update the selection logic in the routine that
iterates cluster_n (the loop using RMSNormKernel._estimate_smem_bytes) to first
collect cluster_n values that divide H, then return the first one with
smem_needed <= max_smem_bytes, and if none fit, return the largest cluster_n
from the divisors list (the largest valid divisor of H from [1,2,4,8,16])
instead of hardcoding 16 so divisibility is always preserved.
- Around line 1292-1310: Replace the silent-copy-prone reshape calls with view
for the 3D path: when len(shape) == 3, use input.view(M, H) and out.view(M, H)
instead of input.reshape(...) and out.reshape(...), so any non-contiguous
layouts raise an error instead of creating a hidden copy; keep the subsequent
is_contiguous() check on input_2d and out_2d and pass those variables to
_get_compiled_rmsnorm_kernel and kernel(...) as before (symbols: input_2d,
out_2d, input.reshape/out.reshape -> input.view/out.view, is_contiguous,
_get_compiled_rmsnorm_kernel, kernel).
```diff
 if len(shape) == 3:
     M = shape[0] * shape[1]
-    input_2d = input.view(M, H)
-    out_2d = out.view(M, H)
+    input_2d = input.reshape(M, H)
+    out_2d = out.reshape(M, H)
 else:
     M = shape[0]
     input_2d = input
     out_2d = out

+is_contiguous = input_2d.is_contiguous() and out_2d.is_contiguous()
 kernel = _get_compiled_rmsnorm_kernel(
-    _torch_dtype_to_str(input.dtype), H, weight_bias, enable_pdl
+    _torch_dtype_to_str(input.dtype),
+    H,
+    weight_bias,
+    enable_pdl,
+    get_sm_version(input.device),
+    contiguous=is_contiguous,
 )
 kernel(input_2d, weight, out_2d, M, eps)
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Verify the current flattening path and check for 3D non-contiguous guard coverage.
sed -n '1288,1312p' flashinfer/norm/kernels/rmsnorm.py
echo "----"
rg -n "reshape\\(M, H\\)|view\\(M, H\\)|is_contiguous\\(" flashinfer/norm/kernels/rmsnorm.py -C 2
echo "----"
rg -n "test_norm\\(|contiguous" tests/utils/test_norm.py -C 3
```

🏁 Script executed:

```shell
# Extract the test_qknorm function to see how it validates the output
sed -n '162,190p' tests/utils/test_norm.py
echo "----"
# Check what the kernel function signature and behavior expects
sed -n '1270,1320p' flashinfer/norm/kernels/rmsnorm.py
```
Use view() instead of reshape() for the 3D flattening path to avoid silent data loss.
Lines 1294–1295 use reshape(), which can silently create copies for certain non-contiguous 3D tensor layouts. If a copy is created, the kernel writes to a temporary, not the original out tensor. The subsequent is_contiguous check on line 1301 inspects the reshaped tensors (which are always contiguous if they exist), so it cannot detect this issue.
Replacing reshape() with view() will raise an explicit error if the reshape would require a copy, making the constraint clear and catching misuse before silent corruption occurs.
Proposed fix

```diff
 if len(shape) == 3:
     M = shape[0] * shape[1]
-    input_2d = input.reshape(M, H)
-    out_2d = out.reshape(M, H)
+    try:
+        input_2d = input.view(M, H)
+        out_2d = out.view(M, H)
+    except RuntimeError as e:
+        raise ValueError(
+            "rmsnorm_cute expects 3D tensors flattenable to (M, H) without copy; "
+            "call contiguous() before rmsnorm_cute for this layout."
+        ) from e
```
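The same hazard exists in NumPy, whose `reshape` also silently returns a copy for non-contiguous inputs, so the failure mode is easy to demonstrate without a GPU (NumPy stands in here for the PyTorch semantics discussed above):

```python
import numpy as np

a = np.zeros((4, 4), dtype=np.float32)
t = a.T                       # non-contiguous view of a's storage

flat = t.reshape(16)          # silently a copy: t is not C-contiguous
flat[0] = 1.0
assert a[0, 0] == 0.0         # the write landed in a temporary, not in a

b = np.zeros((4, 4), dtype=np.float32)
flat_view = b.reshape(16)     # b is contiguous, so this is a true view
flat_view[0] = 1.0
assert b[0, 0] == 1.0         # the write reached the original storage
```

In PyTorch, `t.view(16)` raises a RuntimeError for such layouts, which is exactly the loud failure this review asks for instead of silent corruption.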
[SUCCESS] Pipeline #46031994: 9/20 passed
I cancelled the PR test because CI won't pass before #2781 lands; please re-trigger the test after that PR is merged.
The existing RMSNormKernel reads input from global memory twice: once in
Phase 1 (to compute sum of squares) and again in Phase 2 (to compute the
normalised output). For small hidden sizes where the input row fits in
shared memory, we can cache it in smem after the first load and re-use it
in Phase 2, reducing global memory traffic from 4*d*sizeof(T) to
3*d*sizeof(T) bytes per row (a 25% improvement).
New kernels added to include/flashinfer/norm.cuh:
- RMSNormSmemKernel<VEC_SIZE, T>: stores input (as T, not float) into
shared memory during Phase 1 using vectorised 128-bit stores, then
loads it back with vectorised 128-bit reads in Phase 2.
- RMSNormQuantSmemKernel<VEC_SIZE, T, O>: same optimisation applied to
the FP8-quantised variant.
Shared memory layout:
[0, align16): warp reduction buffer (float, 16-byte aligned)
[align16, ...): input cache (T, d elements)
Dispatch logic in RMSNorm(), RMSNormQuant(), and GemmaRMSNorm():
1. Try to set max dynamic smem to smem_size_smem via
cudaFuncSetAttribute; if it succeeds, launch the smem-caching kernel.
2. If smem is insufficient (hidden size too large), fall back to the
original two-pass global-memory kernel transparently.
For bfloat16, the smem-caching variant fits without extended smem for
d <= ~24000 (48 KB limit). Common model sizes (2880, 4096, 7168, 8192)
all fit comfortably.
Note: the new CuTe DSL kernels (PR #2777) already keep input in register
memory throughout and do not re-read global memory, so they are not
affected by this change. This optimisation targets the CUDA C++ fallback
path.
AI-assisted implementation.
Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
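For reference, the math both the smem-caching CUDA kernels and the CuTe-DSL kernels implement is plain two-phase RMSNorm: Phase 1 reduces the sum of squares over the row, Phase 2 rescales each element, which is why re-reading the row between phases is the avoidable cost. A NumPy sketch of the computation (not of the kernel):

```python
import numpy as np

def rmsnorm(x: np.ndarray, w: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Two-phase RMSNorm over the last dimension.

    Phase 1 reads each row once to form the mean of squares; Phase 2 reads
    it again to normalise -- the second read is what the smem cache avoids.
    """
    mean_sq = np.mean(x.astype(np.float64) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(mean_sq + eps)) * w

x = np.array([[3.0, 4.0]])
y = rmsnorm(x, np.ones(2))
assert np.allclose(y, x / np.sqrt(12.5 + 1e-6))
```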
Force-pushed from 6377688 to 3850674 (compare)
♻️ Duplicate comments (2)
flashinfer/norm/kernels/rmsnorm.py (2)

**1292-1310: ⚠️ Potential issue | 🟠 Major: Consider using `view()` instead of `reshape()` for 3D tensor flattening.**

Lines 1294-1295 use `reshape()`, which can silently create copies for certain non-contiguous 3D layouts. If a copy is created, the kernel writes to a temporary tensor, not the original `out`. The subsequent `is_contiguous()` check on the reshaped tensors won't detect this since reshaped copies are contiguous. Using `view()` would raise an error for non-flattenable layouts, making the constraint explicit. This was flagged in a previous review but appears unaddressed.

Suggested fix:

```diff
 if len(shape) == 3:
     M = shape[0] * shape[1]
-    input_2d = input.reshape(M, H)
-    out_2d = out.reshape(M, H)
+    input_2d = input.view(M, H)
+    out_2d = out.view(M, H)
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/rmsnorm.py` around lines 1292-1310, replace input.reshape(M, H) and out.reshape(M, H) with input.view(M, H) and out.view(M, H) in the branch where len(shape) == 3 so non-flattenable layouts raise an error instead of producing a copy; keep the subsequent is_contiguous check on input_2d/out_2d and ensure the kernel call kernel(input_2d, weight, out_2d, M, eps) still receives views of the original tensors so writes go to the original out.

**106-123: ⚠️ Potential issue | 🟡 Minor: Fallback cluster_n=16 may still violate divisibility constraint.**

The loop correctly filters candidates by `H % cluster_n == 0`, but the fallback on line 123 unconditionally returns 16. If no valid cluster_n fits in SMEM and H is not divisible by 16, returning 16 would break the invariant. Although this scenario is rare (most cluster sizes would fit for reasonable H values), consider returning the last valid candidate that passed the divisibility check:

Suggested fix:

```diff
 @staticmethod
 def _compute_cluster_n(H: int, dtype: cutlass.Numeric, sm_version: int) -> int:
     """Compute optimal cluster size based on H and device shared memory."""
     if sm_version < 90:
         return 1
     props = torch.cuda.get_device_properties(torch.cuda.current_device())
     max_smem_bytes = props.shared_memory_per_block_optin
     elem_size = dtype.width // 8
+    valid_candidates = [c for c in [1, 2, 4, 8, 16] if H % c == 0]
-    for cluster_n in [1, 2, 4, 8, 16]:
-        if H % cluster_n != 0:
-            continue
+    for cluster_n in valid_candidates:
         smem_needed = RMSNormKernel._estimate_smem_bytes(H, cluster_n, elem_size)
         if smem_needed <= max_smem_bytes:
             return cluster_n
-    return 16
+    return valid_candidates[-1]  # Largest valid divisor
```

🤖 Prompt for AI Agents: Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/kernels/rmsnorm.py` around lines 106-123, the fallback in RMSNormKernel._compute_cluster_n unconditionally returns 16, which can violate the H % cluster_n == 0 invariant; record the last candidate that passed the divisibility check while iterating [1, 2, 4, 8, 16], test SMEM fit with RMSNormKernel._estimate_smem_bytes, and if none fits return that divisible fallback (or 1) instead of always returning 16.
/bot run
[FAILED] Pipeline #46109078: 10/20 passed
nv-yunzheq left a comment:
Unit test looks good
## 📌 Description

Rewrites all CuTe-DSL RMSNorm kernel variants (`rmsnorm`, `gemma_rmsnorm`, `fused_add_rmsnorm`, `gemma_fused_add_rmsnorm`, `rmsnorm_quant`, `fused_add_rmsnorm_quant`, `qk_rmsnorm`, `gemma_qk_rmsnorm`).

**Key changes:**

* Multi-row blocks with async global-to-shared copy (cp.async): each thread block processes multiple rows, improving wave utilization and hiding memory latency. Falls back to synchronous copies when alignment or shared memory constraints prevent async usage.
* Cluster reduction on SM90+: for large hidden sizes (H > max single-CTA capacity), the workload is split across a CTA cluster that reduces partial sums via shared memory, avoiding the need for a single CTA to handle the full row.
* Vectorized FP8 convert+store via the PTX intrinsic `cvt.rn.satfinite.e4m3x2.f32`, dramatically improving quantization kernel throughput.
* Occupancy-aware shared memory management.
* Non-contiguous tensor support without performance loss: uses dual-path compilation — a compact kernel for contiguous inputs (optimal codegen) and a strided kernel for non-contiguous inputs (symbolic row strides). Runtime dispatch via `is_contiguous()` ensures zero overhead for the common contiguous case.

<details>
<summary>Click to see B200 performance comparison data (Peak 8 TB/s)</summary>

**RMSNorm**
Before: 
After: 

**QK RMSNorm**
Before: 
After: 

**Add + RMSNorm + FP8 Quantize**
Before: 
After: 
</details>

<details>
<summary>Click to see H200 performance comparison data (Peak 4.8 TB/s)</summary>

**RMSNorm**
Before: 
After: 

**RMSNorm + FP8 Quantize**
Before: 
After: 

**Add + RMSNorm + FP8 Quantize**
Before: 
After: 
</details>

## 🔍 Related Issues

flashinfer-ai#2396 flashinfer-ai#2771

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * SM-version aware kernels and cluster-based tiling for multi-CTA execution
  * Contiguity-aware selection for compact vs. strided tensor paths
  * Hardware-accelerated FP8/E4M3 conversion and packed storage routines
  * New exposed utilities for device SM queries and cluster-backed reductions
* **Improvements**
  * Async copy paths, expanded shared-memory and cluster-reduction support
  * Per-cluster memory/tiling estimation and improved multi-cluster reduction handling
  * Public APIs now accept an optional SM-version hint and infer/preserve contiguity
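The `cvt.rn.satfinite.e4m3x2.f32` instruction mentioned above converts two f32 values to packed e4m3 with round-to-nearest and saturation at the largest finite value (±448). A pure-Python model of the per-value semantics can be handy for sanity-checking quantized outputs; this is a sketch under OCP e4m3 assumptions (bias 7, 3 mantissa bits, subnormals below 2^-6), not the PTX implementation:

```python
import math

def f32_to_e4m3(x: float) -> float:
    """Round x to the nearest finite e4m3 value (4 exponent bits, 3 mantissa
    bits, bias 7), saturating to +/-448 like cvt.rn.satfinite."""
    if x == 0.0:
        return 0.0
    sign = math.copysign(1.0, x)
    a = min(abs(x), 448.0)                  # satfinite: clamp to max finite
    e = max(math.floor(math.log2(a)), -6)   # -6: subnormal threshold
    ulp = 2.0 ** (e - 3)                    # spacing of e4m3 values near a
    q = min(round(a / ulp) * ulp, 448.0)    # round-half-to-even on the grid
    return sign * q

assert f32_to_e4m3(1000.0) == 448.0   # saturates instead of overflowing
assert f32_to_e4m3(1.2) == 1.25       # nearest representable value
assert f32_to_e4m3(-0.3) == -0.3125
```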