rmsnorm cache input single load #2562
Conversation
Summary of Changes (Gemini Code Assist)

This pull request optimizes the RMSNorm operations by leveraging shared memory to cache input data. By dynamically assessing shared memory availability, the host dispatch can choose the most efficient kernel execution path, reducing global memory bandwidth pressure and improving performance. It also provides more flexible control over warp allocation for these kernels.
📝 Walkthrough

Adds an env-driven warp-count override and runtime warp sizing, a cached shared-memory opt-in query, and a templated CACHE_INPUT option for the RMSNorm/RMSNormQuant kernels, so host dispatch picks cached/non-cached kernel variants and adjusts shared memory at launch.
⚠️ Outside diff range comments (1)
include/flashinfer/norm.cuh (1)

Lines 747-778: ⚠️ Potential issue | 🟡 Minor

`GemmaRMSNorm` has inconsistent smem sizing and misses the input-caching optimization. Two concerns:

1. Inconsistent `smem_size`: line 756 allocates `num_warps * sizeof(float)`, but `RMSNorm` (lines 169-170) allocates `ceil_div(num_warps, 4u) * 4u * sizeof(float)` for the no-cache path. This is currently safe because the kernel only touches `smem[0..num_warps-1]`, but it is inconsistent with the updated pattern and fragile if the reduction ever reads into the padded region.
2. Missing `CACHE_INPUT` path: `RMSNorm` and `RMSNormQuant` were updated with the input-caching dispatch (lines 195-207 and 341-355), but `GemmaRMSNorm` still always launches with `CACHE_INPUT=false`. Was this intentional, or should it also benefit from the caching optimization?

Suggested fix for consistency (at minimum):

```diff
- const uint32_t smem_size = num_warps * sizeof(float);
+ const uint32_t smem_reduce_elems = ceil_div(num_warps, 4u) * 4u;
+ const uint32_t smem_size = smem_reduce_elems * sizeof(float);
```
🧹 Nitpick comments (2)
include/flashinfer/norm.cuh (2)
Lines 37-51: `strtoul` overflow not handled.

If the environment variable contains a value exceeding `ULONG_MAX`, `strtoul` returns `ULONG_MAX` while still leaving `end` pointing past the digits, so the check on line 45 passes and the value is `static_cast<int>` on line 48, yielding implementation-defined behavior. Consider checking `errno == ERANGE` or adding an upper-bound clamp (e.g., `parsed > 1024`).

Suggested fix:

```diff
+#include <cerrno>
 ...
 inline int GetRMSNormNumWarpsOverrideFromEnv() {
   static int num_warps_override = []() -> int {
     const char* env = std::getenv("FLASHINFER_RMSNORM_NUM_WARPS");
     if (env == nullptr || env[0] == 0) {
       return 0;
     }
     char* end = nullptr;
+    errno = 0;
     unsigned long parsed = std::strtoul(env, &end, 10);
-    if (end == env || *end != 0 || parsed == 0) {
+    if (end == env || *end != 0 || parsed == 0 || errno == ERANGE || parsed > 1024) {
       return 0;
     }
     return static_cast<int>(parsed);
   }();
   return num_warps_override;
 }
```
Lines 173-178: Consider caching the device shared memory limit.

`cudaGetDevice` + `cudaDeviceGetAttribute` are called on every `RMSNorm`/`RMSNormQuant` invocation (also at lines 319-323). For high-frequency call sites these add host-side overhead. A `static` cache (similar to the env-var pattern above) would avoid repeated queries.

Sketch:

```diff
+inline int GetMaxSmemPerBlock() {
+  static int max_smem = []() {
+    int device;
+    cudaGetDevice(&device);
+    int val;
+    cudaDeviceGetAttribute(&val, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+    return val;
+  }();
+  return max_smem;
+}
```
Code Review
The pull request introduces a significant optimization for the RMSNorm and RMSNormQuant kernels by caching input data in shared memory. This effectively reduces global memory traffic by avoiding a second load of the input tensor during the normalization pass. The implementation correctly handles shared memory limits by dynamically checking device attributes and falling back to the non-cached version when necessary. I have provided feedback on minor cleanup opportunities, such as removing unused variables and optimizing redundant CUDA API calls.
include/flashinfer/norm.cuh (outdated)

```cpp
if (cache_input) {
  auto kernel = RMSNormKernel<VEC_SIZE, T, true>;
  FLASHINFER_CUDA_CALL(
      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
```
The call to cudaFuncSetAttribute is redundant if smem_size is less than or equal to the default 48KB. It is generally recommended to only call this when opting into larger shared memory sizes to avoid unnecessary driver calls.
```cpp
if (smem_size > 48 * 1024) {
  FLASHINFER_CUDA_CALL(
      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
```
Addressed the review comments in commit

Validation on B200: no regression seen in the direct FlashInfer microbench versus previous tuned numbers.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@include/flashinfer/norm.cuh`:
- Around lines 75-94: `GetRMSNormMaxSharedMemoryPerBlockOptin` currently caches a single value in a static `std::atomic<int> max_smem_per_block` and therefore returns the same shared-memory limit for all CUDA devices. Change the implementation to key the cache by device ordinal, or avoid caching across devices: call `cudaGetDevice()` first, then either query `cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin, ...)` on every call (cheap and safe) or maintain a small per-device cache keyed by the device ID (e.g., a fixed-size array or map of atomics indexed by the `cudaGetDevice()` result) and store/load the per-device value instead of the single static `max_smem_per_block`. Ensure you still handle `cudaGetDevice`/`cudaDeviceGetAttribute` failures by falling back to `kDefaultSmemLimit`.
🧹 Nitpick comments (1)
include/flashinfer/norm.cuh (1)
Lines 216-234: Consider extracting the duplicated cache/no-cache dispatch into a helper.

The `if (cache_input) { ... } else { ... }` block is repeated near-identically across `RMSNorm`, `RMSNormQuant`, and `GemmaRMSNorm`. A small templated lambda or helper function could eliminate ~50 lines of near-identical code and reduce the maintenance surface.
Addressed the new CodeRabbit major comment in commit

Change: this removes the single-device assumption and is safe for multi-GPU processes (tensor-parallel/multi-device runtime).

Quick benchmark validation after this change (B200): no performance regression vs prior tuned results.
Summary
This PR improves FlashInfer RMSNorm performance in two steps:

- Tuned launch warp counts, with an optional `FLASHINFER_RMSNORM_NUM_WARPS` environment override
- Shared-memory input caching across `RMSNorm`, `RMSNormQuant`, and `GemmaRMSNorm`

Commits:

- `6caa8721` Optimize RMSNorm by caching input to avoid second global read
- `f7f1a437` Tune RMSNorm launch warps to reduce reduction overhead

Performance (B200, same-node)

- 8192x8192, bf16: 80.078 us → 75.491 us → 49.261 us → 46.744 us
- 8192x2880, bf16: 28.688 us → 18.341 us → 14.463 us

Validation

- Numerical check passed (rtol=1e-2, atol=1e-2)
- `tests/utils/test_norm.py::test_norm[True-False-True-dtype0-8192-99]` passed
- `rmsnorm_quant` smoke test passed