
rmsnorm cache input single load#2562

Closed
vsabavat wants to merge 4 commits into flashinfer-ai:main from vsabavat:vsabavat/rmsnorm-cache-input-single-load

Conversation


@vsabavat vsabavat commented Feb 14, 2026

Summary

This PR improves FlashInfer RMSNorm performance in two steps:

  1. Cache RMSNorm input in shared memory to avoid the second global-memory read (with fallback when shared memory is insufficient).
  2. Tune launch warps to reduce reduction overhead, with runtime override support:
    • FLASHINFER_RMSNORM_NUM_WARPS
    • applied to RMSNorm, RMSNormQuant, and GemmaRMSNorm
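As a CPU-side illustration of step 1, the two-pass structure looks roughly like this; the `cache` buffer stands in for the kernel's dynamic shared memory, and the names (`RMSNormCached`, `eps`) are illustrative, not FlashInfer's actual API:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU sketch of the two-pass RMSNorm structure this PR optimizes.
// Pass 1 computes the sum of squares; pass 2 scales each element.
// In the CUDA kernel, `cache` corresponds to the shared-memory buffer
// that lets pass 2 skip the second global-memory read of `input`.
std::vector<float> RMSNormCached(const std::vector<float>& input,
                                 const std::vector<float>& weight,
                                 float eps = 1e-6f) {
  const size_t d = input.size();
  std::vector<float> cache(d);  // stands in for shared memory
  float sum_sq = 0.0f;
  for (size_t i = 0; i < d; ++i) {  // pass 1: load input once, cache it
    cache[i] = input[i];
    sum_sq += cache[i] * cache[i];
  }
  const float rms_rcp =
      1.0f / std::sqrt(sum_sq / static_cast<float>(d) + eps);
  std::vector<float> out(d);
  for (size_t i = 0; i < d; ++i) {  // pass 2: read the cache, not input
    out[i] = cache[i] * rms_rcp * weight[i];
  }
  return out;
}
```

The fallback path in the PR simply re-reads `input` in pass 2 when the row does not fit in shared memory.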

Commits:

  • 6caa8721 Optimize RMSNorm by caching input to avoid second global read
  • f7f1a437 Tune RMSNorm launch warps to reduce reduction overhead

Performance (B200, same-node)

  • 8192x8192, bf16

    • FlashInfer baseline: 80.078 us
    • After cache fix: 75.491 us
    • After tuning: 49.261 us
    • CTM reference: 46.744 us
    • Remaining gap vs CTM: +5.39%
  • 8192x2880, bf16

    • FlashInfer previous: 28.688 us
    • After tuning: 18.341 us
    • CTM reference: 14.463 us
    • Remaining gap vs CTM: +26.82%

Validation

  • Torch reference check passed (rtol=1e-2, atol=1e-2)
  • tests/utils/test_norm.py::test_norm[True-False-True-dtype0-8192-99] passed
  • rmsnorm_quant smoke test passed

Summary by CodeRabbit

  • Performance Improvements

    • RMSNorm and RMSNormQuant now use dynamic warp sizing and optional in-kernel input caching for better throughput and shared-memory usage.
    • Environment-driven overrides allow tuning of warp/config choices at runtime.
  • New Features

    • Runtime decision to enable or disable input caching per launch, improving adaptability across devices and workloads.

@gemini-code-assist
Contributor

Summary of Changes

Hello @vsabavat, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the RMSNorm operations by intelligently leveraging shared memory to cache input data. By dynamically assessing shared memory availability, the system can choose the most efficient kernel execution path, thereby reducing global memory bandwidth pressure and improving performance. Additionally, it provides more flexible control over warp allocation for these kernels.

Highlights

  • RMSNorm Input Caching: Introduced a CACHE_INPUT template parameter to the RMSNormKernel and RMSNormQuantKernel to enable conditional caching of input data in shared memory, aiming to reduce global memory access.
  • Dynamic Shared Memory Allocation: Implemented logic to dynamically determine if input caching is feasible based on available shared memory per block, launching the appropriate kernel variant (CACHE_INPUT=true or false).
  • Warp Calculation Refinement: Added a new utility function GetRMSNormNumWarps to calculate the optimal number of warps for RMSNorm kernels, with an option to override this value via the FLASHINFER_RMSNORM_NUM_WARPS environment variable.
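The warp-sizing-with-override behavior described above can be sketched on the CPU as follows; the clamp bounds and the name `GetNumWarpsSketch` are assumptions for illustration, not the exact logic of `GetRMSNormNumWarps` in norm.cuh:

```cpp
#include <algorithm>
#include <cassert>
#include <cerrno>
#include <cstdlib>

inline unsigned CeilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

// Sketch: the FLASHINFER_RMSNORM_NUM_WARPS override wins when it parses
// as a value in [1, 32]; otherwise warps are sized so each thread covers
// `vec_size` elements of the hidden dimension `d`, clamped to a full
// 1024-thread block (32 warps).
int GetNumWarpsSketch(unsigned d, unsigned vec_size) {
  if (const char* env = std::getenv("FLASHINFER_RMSNORM_NUM_WARPS")) {
    char* end = nullptr;
    errno = 0;
    unsigned long v = std::strtoul(env, &end, 10);
    if (end != env && *end == '\0' && errno != ERANGE && v >= 1 && v <= 32) {
      return static_cast<int>(v);  // valid override takes precedence
    }
  }
  unsigned threads = CeilDiv(d, vec_size);  // one vector chunk per thread
  unsigned warps = CeilDiv(std::min(threads, 1024u), 32u);
  return static_cast<int>(std::max(warps, 1u));
}
```

Invalid or out-of-range override values fall through to the computed default rather than failing.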


Changelog
  • include/flashinfer/norm.cuh
    • Added cstdlib header for std::getenv and std::strtoul.
    • Introduced GetRMSNormNumWarpsOverrideFromEnv function to read warp override from environment variable.
    • Implemented GetRMSNormNumWarps function to calculate optimal warps, incorporating environment variable override.
    • Modified RMSNormKernel and RMSNormQuantKernel templates to accept a CACHE_INPUT boolean parameter.
    • Allocated shared memory (smem_input) for caching input data within the kernels.
    • Updated kernel loops to conditionally load input data from shared memory if CACHE_INPUT is true.
    • Removed the block_size calculation and replaced direct num_warps calculation with a call to GetRMSNormNumWarps in RMSNorm and RMSNormQuant functions.
    • Added logic in RMSNorm and RMSNormQuant functions to determine cache_input based on shared memory capacity and launch the correct kernel specialization.
    • Adjusted shared memory size calculation to account for input caching.
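A minimal CPU model of the `CACHE_INPUT` template switch and its host-side dispatch; the names here are illustrative (the real kernels are CUDA templates in norm.cuh), but the pattern is the same: a compile-time bool selects the read path, and the host picks the specialization at runtime:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Compile-time switch between reading a cached copy and re-reading the
// source, mirroring RMSNormKernel<..., true> vs RMSNormKernel<..., false>.
template <bool CACHE_INPUT>
float SecondPassRead(const std::vector<float>& global_input,
                     const std::vector<float>& smem_input, size_t i) {
  if constexpr (CACHE_INPUT) {
    return smem_input[i];    // served from the shared-memory copy
  } else {
    return global_input[i];  // fallback: second global read
  }
}

// Host-side dispatch model: the runtime feasibility check (does the row
// fit in shared memory?) selects which specialization gets launched.
float ReadElem(bool cache_input, const std::vector<float>& g,
               const std::vector<float>& s, size_t i) {
  return cache_input ? SecondPassRead<true>(g, s, i)
                     : SecondPassRead<false>(g, s, i);
}
```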

@coderabbitai
Contributor

coderabbitai bot commented Feb 14, 2026

📝 Walkthrough


Adds env-driven warp-count override and runtime warp-sizing, a cached shared-memory opt-in query, and a templated CACHE_INPUT option for RMSNorm/RMSNormQuant kernels so host dispatch picks cached/non-cached kernel variants and adjusts shared memory at launch.

Changes

  • RMSNorm kernel & launch logic (include/flashinfer/norm.cuh): Add GetRMSNormNumWarpsOverrideFromEnv(), GetRMSNormNumWarps(d, vec_size), and GetRMSNormMaxSharedMemoryPerBlockOptin(); replace fixed warp sizing with runtime num_warps and adjust launch paths to use it.
  • Kernel templates & in-kernel caching (include/flashinfer/norm.cuh): Introduce a bool CACHE_INPUT template parameter to RMSNormKernel and RMSNormQuantKernel; add smem_input backing and conditional load/store paths to support optional in-kernel input caching.
  • Host-side dispatch & shared-memory sizing (include/flashinfer/norm.cuh): Host computes num_warps, queries opt-in shared memory, computes smem_size with caching considered, and dispatches templated kernel variants (..., true or ..., false) based on feasibility and env override.
  • Headers / helpers / minor infra (include/flashinfer/norm.cuh): Added atomic/errno includes and inline helpers in flashinfer::norm; small adjustments to kernel implementations to support optional caching paths without changing public RMSNorm/RMSNormQuant signatures.

Sequence Diagram(s)

mermaid
sequenceDiagram
participant Host as Host (CPU)
participant Env as Env (process)
participant CUDA as CUDA Runtime
participant GPU as GPU Device
Host->>Env: read FLASHINFER_RMSNORM_NUM_WARPS
Host->>Host: GetRMSNormNumWarps(d, vec_size)
Host->>GPU: GetRMSNormMaxSharedMemoryPerBlockOptin()
Host->>Host: decide CACHE_INPUT feasibility, compute smem_size
Host->>CUDA: launch RMSNormKernel/CachedVariant(num_warps, smem_size)
CUDA->>GPU: schedule kernel with block/warp sizing
GPU->>GPU: kernel executes (conditional input caching / shared memory paths)
GPU-->>Host: completion / results

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • yzh119
  • IwakuraRein
  • kahyunnam
  • jiahanc


🚥 Pre-merge checks: 4 passed
  • Title check ✅ — The title 'rmsnorm cache input single load' directly describes the main optimization: caching the RMSNorm input to avoid redundant global memory reads.
  • Description check ✅ — The PR description includes a summary of changes, performance benchmarks, and validation results, but lacks some template sections such as explicit issue links and pre-commit/test checklists.
  • Docstring coverage ✅ — No functions found in the changed files to evaluate; check skipped.
  • Merge conflict detection ✅ — No merge conflicts detected when merging into main.



No actionable comments were generated in the recent review. 🎉

🧹 Recent nitpick comments
include/flashinfer/norm.cuh (2)

113-113: Stale comment: num_warps can now reach 32 with the env override.

With FLASHINFER_RMSNORM_NUM_WARPS override, GetRMSNormNumWarps can return up to ceil_div(1024, 32) = 32. The reduction logic handles this correctly (warp 0 has exactly 32 lanes to cover 32 partial sums), but the comment is misleading. Same comment exists in RMSNormQuantKernel at line 256.

Suggested wording
-  // NOTE(Zihao): it's guaranteed that num_warps should be smaller than 32
+  // NOTE: num_warps is at most 32 (1024 threads / 32 per warp)

226-242: Consider extracting the kernel-select-and-launch pattern to reduce duplication.

The if (cache_input) { ... } else { ... } dispatch block is repeated nearly identically in RMSNorm, RMSNormQuant, and GemmaRMSNorm. A small templated helper or lambda could consolidate this:

auto launch = [&](auto kernel) {
  if (smem_size > 48 * 1024) {
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  }
  FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, /* args... */));
};

Not urgent given the macro context, but would reduce ~60 lines of near-identical code across the three launchers.



Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/norm.cuh (1)

747-778: ⚠️ Potential issue | 🟡 Minor

GemmaRMSNorm has inconsistent smem sizing and misses the input-caching optimization.

Two concerns:

  1. Inconsistent smem_size: Line 756 allocates num_warps * sizeof(float), but RMSNorm (line 169-170) allocates ceil_div(num_warps, 4u) * 4u * sizeof(float) for the no-cache path. This is currently safe because the kernel only touches smem[0..num_warps-1], but it's inconsistent with the updated pattern and fragile if the reduction ever reads into the padded region.

  2. Missing CACHE_INPUT path: RMSNorm and RMSNormQuant were updated with the input-caching dispatch (lines 195-207, 341-355), but GemmaRMSNorm still always launches with CACHE_INPUT=false. Was this intentional, or should it also benefit from the caching optimization?

Suggested fix for consistency (at minimum)
-  const uint32_t smem_size = num_warps * sizeof(float);
+  const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
+  const uint32_t smem_size = smem_reduce_elems * sizeof(float);
🧹 Nitpick comments (2)
include/flashinfer/norm.cuh (2)

37-51: strtoul overflow not handled.

If the environment variable holds a value exceeding ULONG_MAX, strtoul returns ULONG_MAX while still consuming every digit, so the end == env and *end != 0 checks on line 45 pass and the out-of-range value is static_cast<int> on line 48, yielding an implementation-defined result. Consider checking errno == ERANGE or adding an upper-bound clamp (e.g., parsed > 1024).

Suggested fix
+#include <cerrno>
 ...
 inline int GetRMSNormNumWarpsOverrideFromEnv() {
   static int num_warps_override = []() -> int {
     const char* env = std::getenv("FLASHINFER_RMSNORM_NUM_WARPS");
     if (env == nullptr || env[0] == 0) {
       return 0;
     }
     char* end = nullptr;
+    errno = 0;
     unsigned long parsed = std::strtoul(env, &end, 10);
-    if (end == env || *end != 0 || parsed == 0) {
+    if (end == env || *end != 0 || parsed == 0 || errno == ERANGE || parsed > 1024) {
       return 0;
     }
     return static_cast<int>(parsed);
   }();
   return num_warps_override;
 }

173-178: Consider caching the device shared memory limit.

cudaGetDevice + cudaDeviceGetAttribute are called on every RMSNorm / RMSNormQuant invocation (also at lines 319-323). For high-frequency call sites these add host-side overhead. A static cache (similar to the env-var pattern above) would avoid repeated queries.

Sketch
+inline int GetMaxSmemPerBlock() {
+  static int max_smem = []() {
+    int device;
+    cudaGetDevice(&device);
+    int val;
+    cudaDeviceGetAttribute(&val, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+    return val;
+  }();
+  return max_smem;
+}

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request introduces a significant optimization for the RMSNorm and RMSNormQuant kernels by caching input data in shared memory. This effectively reduces global memory traffic by avoiding a second load of the input tensor during the normalization pass. The implementation correctly handles shared memory limits by dynamically checking device attributes and falling back to the non-cached version when necessary. I have provided feedback on minor cleanup opportunities, such as removing unused variables and optimizing redundant CUDA API calls.

Comment on lines +175 to +177
int max_smem_per_block;
FLASHINFER_CUDA_CALL(
    cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));

Severity: medium

Calling cudaDeviceGetAttribute on every kernel launch can introduce unnecessary overhead. Consider caching the max_smem_per_block value using a static variable or a thread-safe helper function, similar to how GetCudaMultiProcessorCount is implemented in flashinfer/utils.cuh.

      if (cache_input) {
        auto kernel = RMSNormKernel<VEC_SIZE, T, true>;
        FLASHINFER_CUDA_CALL(
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
Severity: medium

The call to cudaFuncSetAttribute is redundant if smem_size is less than or equal to the default 48KB. It is generally recommended to only call this when opting into larger shared memory sizes to avoid unnecessary driver calls.

      if (smem_size > 48 * 1024) {
        FLASHINFER_CUDA_CALL(
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      }

@vsabavat
Author

Addressed the review comments in commit f60ce0af on this branch:

  • Added strtoul overflow handling (errno == ERANGE + int-range guard) in GetRMSNormNumWarpsOverrideFromEnv.
  • Added cached GetRMSNormMaxSharedMemoryPerBlockOptin() and switched RMSNorm/RMSNormQuant/GemmaRMSNorm to use it.
  • Updated GemmaRMSNorm shared-memory sizing to aligned reduce buffer and added cache/non-cache dispatch parity with RMSNorm.
  • Gated cudaFuncSetAttribute(...MaxDynamicSharedMemorySize...) behind smem_size > 48KB in RMSNorm paths.

Validation on B200 (pool0-0110, job 210220):

  • Direct FlashInfer microbench (same script as prior tuning):
    • 8192x8192: 49.254 us
    • 8192x2880: 18.209 us
  • Wrapper script check:
    • flashinfer_rmsnorm (TRT-LLM custom op path) 8192x8192: 80.025 us
    • ctm_rmsnorm 8192x8192: 46.779 us

No regression seen in the direct FlashInfer microbench versus previous tuned numbers.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@include/flashinfer/norm.cuh`, around lines 75-94: GetRMSNormMaxSharedMemoryPerBlockOptin caches a single value in a static std::atomic<int>, so it returns the same shared-memory limit for every CUDA device. Key the cache by device ordinal instead, or drop the cross-device cache: call cudaGetDevice() first, then either query cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin, ...) on every call (cheap and safe), or maintain a small per-device cache keyed by the device ID (e.g., a fixed-size array or map of atomics indexed by the cudaGetDevice() result). Keep handling cudaGetDevice / cudaDeviceGetAttribute failures by falling back to kDefaultSmemLimit.
🧹 Nitpick comments (1)
include/flashinfer/norm.cuh (1)

216-234: Consider extracting the duplicated cache/no-cache dispatch into a helper.

The if (cache_input) { ... } else { ... } block is repeated identically across RMSNorm, RMSNormQuant, and GemmaRMSNorm. A small templated lambda or helper function could eliminate ~50 lines of near-identical code and reduce the maintenance surface.

@vsabavat
Author

Addressed the new CodeRabbit major comment in commit 15c94452.

Change:

  • GetRMSNormMaxSharedMemoryPerBlockOptin() now caches by CUDA device ordinal (per-device atomic cache), instead of a single process-global value.
  • Keeps safe fallback to 48KB when CUDA queries fail.

This removes the single-device assumption and is safe for multi-GPU processes (tensor-parallel/multi-device runtime).
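The per-device cache pattern described above can be sketched like this; `QueryLimitFn`, `kMaxDevices`, and `GetSmemOptinCached` are illustrative stand-ins (the query function is injected so the pattern can be shown without a GPU; the real code calls cudaDeviceGetAttribute):

```cpp
#include <atomic>
#include <cassert>

constexpr int kMaxDevices = 64;
constexpr int kDefaultSmemLimit = 48 * 1024;  // safe fallback (48KB)

// Stand-in for the CUDA attribute query; returns <0 on failure.
using QueryLimitFn = int (*)(int device);

// Per-device cached lookup: each ordinal is queried at most once;
// 0 means "not yet cached". A failed query falls back to the 48KB
// default (and, in this sketch, caches the fallback).
int GetSmemOptinCached(int device, QueryLimitFn query) {
  static std::atomic<int> cache[kMaxDevices];  // zero-initialized
  if (device < 0 || device >= kMaxDevices) return kDefaultSmemLimit;
  int cached = cache[device].load(std::memory_order_relaxed);
  if (cached != 0) return cached;
  int limit = query(device);
  if (limit <= 0) limit = kDefaultSmemLimit;  // query failed: fall back
  cache[device].store(limit, std::memory_order_relaxed);
  return limit;
}
```

Relaxed atomics suffice here because the cached value is immutable once set and any racing writers store the same result.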

Quick benchmark validation after this change (B200 pool0-0077, job 210222):

  • shape=8192x8192: 49.361 us
  • shape=8192x2880: 18.472 us

No performance regression vs prior tuned results.

@vsabavat vsabavat marked this pull request as draft February 14, 2026 06:57
@vsabavat vsabavat closed this Feb 14, 2026