feat: support more head dim in RoPE kernel #2109
Conversation
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
**Note:** CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

**Walkthrough:** Added a device helper to handle partial RoPE quantization chunks, replaced per-element vector loads/stores with guarded partial-chunk writes, refactored dynamic kernel dispatch, routed RoPE-quantized flows through `RopeQuantize`, and expanded cos/sin cache tests with four new configurations.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Host
    participant Dispatch as KernelDispatch
    participant GPU
    Note over Host,Dispatch: Host prepares params (head_dim, rotary_dim, no_rope_dim, ...)
    Host->>Dispatch: call launch routine
    Dispatch-->>GPU: select & launch kernel (RopeQuantize / other) with computed vec_size/bdx/bdy
    alt head_dim < rotary_dim
        GPU-->>Host: error return
    else RoPE-quantized path
        GPU->>GPU: RopeQuantizeKernel runs
        GPU->>GPU: call scale_store_partial_chunk for tail lanes (zero-pad, scale, store)
    else Non-RoPE / full-chunk
        GPU->>GPU: regular vector loads/stores
    end
    GPU-->>Host: results / completion
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~35 minutes
Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the flexibility and robustness of the Rotary Positional Embedding (RoPE) implementation by enabling support for arbitrary head dimensions within the `RopeQuantizeKernel`. It introduces a mechanism to gracefully handle partial data chunks in non-RoPE dimensions and refactors the `BatchQKApplyRotaryPosIdsCosSinCache` function to use this improved kernel. These changes ensure correct processing for a broader range of model configurations and simplify future maintenance.
Code Review
This pull request adds support for arbitrary head dimensions in the RoPE kernel by introducing a new helper function `scale_store_partial_chunk` to handle partial memory chunks and refactoring `BatchQKApplyRotaryPosIdsCosSinCache` to use the more general `RopeQuantize` kernel. This is a good simplification that reduces code duplication.

However, I've found a critical issue in how the non-RoPE tensor slices are handled. The pointer arithmetic used to create `q_nope_in` and `k_nope_in` is incorrect for multi-dimensional tensors, which will lead to incorrect memory accesses. I've also included a couple of suggestions to improve code clarity in the new helper function.
The added tests are good, but they seem to be passing despite the critical issue, which might indicate a problem with the test setup or reference implementation.
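To make the behavior under discussion concrete, here is a minimal NumPy model of what a guarded partial-chunk store does: load up to `vec_size` elements, zero-pad past the valid count, scale, and write only the valid lanes. This is an illustrative sketch of the semantics, not the actual CUDA device function or its signature.

```python
import numpy as np

def scale_store_partial_chunk(dst, src, chunk_offset, chunk_valid, vec_size, scale):
    """Illustrative model of the guarded store: read up to vec_size
    elements, zero-pad lanes past chunk_valid, scale, and write back
    only the valid lanes."""
    vec = np.zeros(vec_size, dtype=src.dtype)
    n = max(0, min(chunk_valid, vec_size))
    vec[:n] = src[chunk_offset:chunk_offset + n]
    vec *= scale  # scaling zero-padded lanes is harmless (0 * scale == 0)
    dst[chunk_offset:chunk_offset + n] = vec[:n]

# Tail chunk: only 2 of 4 lanes are valid, the rest are never written.
src = np.arange(10, dtype=np.float32)
dst = np.zeros(10, dtype=np.float32)
scale_store_partial_chunk(dst, src, chunk_offset=8, chunk_valid=2, vec_size=4, scale=2.0)
```

The key property is that lanes beyond `chunk_valid` never touch `dst`, which is what prevents out-of-bounds writes in the tail chunk.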
Actionable comments posted: 0
🧹 Nitpick comments (1)
include/flashinfer/pos_enc.cuh (1)
1052-1120: Consider clarifying the `bdx` template parameter usage.

The kernel dispatch sets the template parameter `bdx=1` (line 1097) while computing a runtime `bdx` value (line 1054). This works because the `rotary_dim` argument is explicitly passed to the RoPE functions, overriding the default `vec_size * bdx`. However, this discrepancy could be confusing for maintainability. Consider either:

- Using the computed `bdx` value as the template parameter (would require a DISPATCH_BDX macro), or
- Adding a comment explaining why the template `bdx` is set to 1 while the runtime `bdx` varies.

Example comment:

```cpp
// Template bdx=1 because rotary_dim is explicitly passed to RoPE functions
auto kernel = RopeQuantizeKernel<INTERLEAVE, vec_size, 1, DType, IdType, QuantType>;
```
📒 Files selected for processing (2)

- `include/flashinfer/pos_enc.cuh` (7 hunks)
- `tests/attention/test_rope.py` (1 hunk)
🔇 Additional comments (5)

`tests/attention/test_rope.py` (1)

303-306: LGTM! Test coverage expanded appropriately.

The new test configurations effectively validate the partial chunk handling introduced in this PR. They cover various scenarios where `no_rope_dim < rope_dim`, which exercises the new `scale_store_partial_chunk` logic for tail chunks.

`include/flashinfer/pos_enc.cuh` (4)
546-551: Correct usage of partial chunk handling.

The `chunk_valid` calculation properly handles tail chunks where `no_rope_dim` is not a multiple of `rope_dim`. The logic correctly computes the number of valid elements in the current chunk and handles the case where `elem_offset >= no_rope_dim` by setting `chunk_valid = 0`.

Also applies to: 566-571
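The `chunk_valid` rule described here can be modeled in a few lines. This is a hypothetical host-side sketch that mirrors the review's description, not the exact kernel code:

```python
def chunk_valid(elem_offset, no_rope_dim, rope_dim):
    # Valid elements in the current no-RoPE chunk: full chunks get
    # rope_dim, the tail chunk gets the remainder, and chunks past
    # the end of the no-RoPE region get 0.
    if elem_offset >= no_rope_dim:
        return 0
    return min(rope_dim, no_rope_dim - elem_offset)

# head_dim=192, rope_dim=128 gives no_rope_dim=64: a single partial chunk
print(chunk_valid(0, 64, 128))  # → 64
```

A chunk whose offset starts at or beyond `no_rope_dim` contributes nothing, which is the guard that makes arbitrary head dimensions safe.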
1140-1196: Dispatch logic is consistent across variants.

The dynamic dispatch logic for `RopeQuantizeAppendPagedKVCache` (GQA/MHA) and `RopeQuantizeAppendPagedMLACache` follows the same pattern as `RopeQuantize`. The `total_blocks_y` calculation correctly accounts for the differences between GQA/MHA (includes V blocks) and MLA (no V blocks).

Note: The same `bdx` template parameter concern mentioned in the previous comment applies here as well.

Also applies to: 1214-1274
236-292: Verify performance impact with targeted benchmarks for partial chunk scenarios.

The code logic is correct with proper boundary checks and zero-padding. However, verification confirms the author's concern: no performance data exists for this code path. The existing benchmark uses `head_size = rotary_dim` (both 128), meaning `no_rope_dim = 0`, so it doesn't exercise the partial chunk handling that this function addresses.

Before merging, run benchmarks with configurations where `no_rope_dim > 0` and `no_rope_dim < rope_dim` (e.g., `head_dim=192, rope_dim=128, no_rope_dim=64`) to quantify the performance impact of the element-by-element fallback path and zero-padding logic.
1286-1312: Performance verification requires manual benchmarking; the routing change logic is correct and well-tested.

The routing to `RopeQuantize` is intentional, uniform across all callers, and thoroughly validated for correctness. Existing tests in `tests/attention/test_rope.py` verify the output against reference implementations for all relevant configurations (head_dim: 64/128/256, partial_rotary_factor: 0.25–1.0). However, the original review specifically requests performance profiling to detect regressions, which cannot be completed automatically in this environment; you must run performance benchmarks locally to measure kernel execution time and throughput across representative workloads.
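For readers unfamiliar with the reference implementations the tests compare against, a partial-rotary RoPE reference can be sketched as follows. This is an illustrative NumPy model assuming the non-interleaved (half-rotated) layout; it is not the repository's actual reference code:

```python
import numpy as np

def apply_partial_rope(x, pos, rotary_dim, theta=1e4):
    """Reference partial RoPE (non-interleaved layout): rotate the first
    rotary_dim elements of the last axis, copy the tail unchanged."""
    half = rotary_dim // 2
    freqs = theta ** (-np.arange(half) / half)  # per-pair rotation frequencies
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:rotary_dim]
    out = x.copy()
    out[..., :half] = x1 * cos - x2 * sin
    out[..., half:rotary_dim] = x2 * cos + x1 * sin
    return out  # elements [rotary_dim:] pass through untouched

x = np.random.default_rng(0).standard_normal((2, 192)).astype(np.float32)
y = apply_partial_rope(x, pos=5, rotary_dim=128)
```

The pass-through tail is exactly the `no_rope_dim` region that the new partial-chunk handling covers on the device side.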
yzh119 left a comment:
LGTM overall, cc @kahyunnam for another look
/bot run
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Is there an issue with CI? It seems like it has been running for 2 days now 😅
Hi @raayandhar, the CI is finished (the result was not returned here for some reason). The PR itself does not bring any regressions and should be ready to merge. I'm running the benchmarks and will merge it as long as there is no performance regression.
This LGTM! I do wonder if we're still adding some unnecessary overhead with the pointwise multiply by 1. I agree we can merge when benchmarking looks OK @yzh119. Thanks @raayandhar for the contribution!
There are indeed some performance regressions @raayandhar @kahyunnam. On H100, before this PR:

```
rope-latency:
    seq_len  FlashInfer    Native      vLLM
0       2.0    0.005936  0.062576  0.007968
1       4.0    0.005952  0.064256  0.008160
2       8.0    0.005888  0.069376  0.008128
3      16.0    0.006112  0.066160  0.008352
4      32.0    0.006240  0.066784  0.008576
5      64.0    0.006752  0.068608  0.009056
6     128.0    0.007808  0.075328  0.010464
7     256.0    0.009664  0.088256  0.012832
8     512.0    0.013472  0.115648  0.019904
9    1024.0    0.020896  0.170496  0.033728
10   2048.0    0.035712  0.290272  0.060896
11   4096.0    0.066240  0.523520  0.114400
12   8192.0    0.129952  0.985888  0.221632
13  16384.0    0.255168  1.897296  0.436032
14  32768.0    0.486576  3.715232  0.864640
15  65536.0    0.953376  7.342368  1.722112
```

After:

```
    seq_len  FlashInfer    Native      vLLM
0       2.0    0.005952  0.063488  0.007968
1       4.0    0.005952  0.064112  0.008128
2       8.0    0.005920  0.069440  0.008128
3      16.0    0.006272  0.067104  0.008384
4      32.0    0.006400  0.067552  0.008576
5      64.0    0.006688  0.068512  0.009056
6     128.0    0.007744  0.075424  0.010464
7     256.0    0.009760  0.088224  0.012832
8     512.0    0.013632  0.115712  0.019872
9    1024.0    0.021120  0.170720  0.033696
10   2048.0    0.036064  0.289760  0.060864
11   4096.0    0.066976  0.524288  0.114528
12   8192.0    0.128800  0.985664  0.221760
13  16384.0    0.259968  1.899248  0.435840
14  32768.0    0.621312  3.711968  0.864608
15  65536.0    1.758672  7.343424  1.722016
```
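The regression in the FlashInfer column is concentrated at large sequence lengths. A quick check of the slowdown at the two largest sizes, with the numbers copied from the benchmark output above:

```python
# FlashInfer latency (ms) at the two largest seq_len values, from the tables above
before = {32768.0: 0.486576, 65536.0: 0.953376}
after = {32768.0: 0.621312, 65536.0: 1.758672}

for seq_len in sorted(before):
    ratio = after[seq_len] / before[seq_len]
    print(f"seq_len={seq_len:.0f}: {ratio:.2f}x slower")
```

Below seq_len 16384 the before/after numbers are within noise of each other, so the regression is specific to the large-batch path.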
Oof ok, I will go and investigate. Could you share your benchmarking scripts?
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (2)
include/flashinfer/pos_enc.cuh (2)
236-290: Minor optimization opportunity in `scale_store_partial_chunk`.

Lines 273-276 scale all `vec_size` elements, including zero-padded ones, when the chunk is partial. While harmless (0 × scale = 0) and likely optimized by the compiler, you could skip scaling invalid elements for marginal gains:

```diff
 #pragma unroll
 for (uint32_t i = 0; i < vec_size; ++i) {
-  vec[i] = vec[i] * scale;
+  uint32_t elem_idx = lane_elem_offset + i;
+  if (elem_idx < chunk_valid) {
+    vec[i] = vec[i] * scale;
+  }
 }
```
1213-1282: MLA dispatch looks correct; minor cleanup possible.

The dispatch correctly handles MLA-specific requirements (`num_kv_heads=1`, no V processing, MLA cache type). Lines 1245-1246 and 1262-1264 introduce duplicate stride variables:

```cpp
size_t k_rope_in_stride_h_dup = k_rope_in_stride;
size_t k_nope_in_stride_h_dup = k_nope_in_stride;
```

These can be removed by directly assigning `k_rope_in_stride` and `k_nope_in_stride` to the params struct fields. This minor cleanup would reduce verbosity.
📒 Files selected for processing (1)

- `include/flashinfer/pos_enc.cuh` (7 hunks)
🔇 Additional comments (5)
include/flashinfer/pos_enc.cuh (5)
544-549: LGTM: Partial chunk handling for non-RoPE dimensions.

The usage of `scale_store_partial_chunk` correctly guards writes when `no_rope_dim` is not a multiple of `rope_dim`. The `chunk_valid` calculation ensures only valid elements are written, preventing out-of-bounds access.

Also applies to: 564-569
1294-1298: LGTM: Guard against invalid head_dim.

The check ensures `head_dim >= rotary_dim` before proceeding, preventing undefined behavior. The clear error message aids debugging.
1361-1380: Routing to RopeQuantize introduces a known performance trade-off.

For arbitrary head dimensions, the code routes through `RopeQuantize` with `quant_scale_q=1.0f` and `quant_scale_kv=1.0f`. This:

- Adds a multiply-by-1.0 operation per element (minor, likely optimized by the compiler)
- Uses a more general kernel instead of the optimized fast-path kernels for standard dimensions

The pointer arithmetic (lines 1367-1372) is correct despite past review concerns. The base pointer offset by `rotary_dim` combined with full-tensor strides produces the correct element addresses for all (idx, head) combinations. Based on PR comments, performance regressions on H100 are known and under investigation with benchmark results.
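The base-pointer-plus-strides claim can be checked numerically. The following NumPy sketch (illustrative shapes and names, not the actual tensor layout code) shows that offsetting a flat base index by `rotary_dim` while keeping the full-tensor strides addresses exactly the no-RoPE tail of each head:

```python
import numpy as np

# q has shape (nnz, num_heads, head_dim); the no-RoPE slice is the tail
# [rotary_dim:] of the last axis. With C-contiguous layout, offsetting the
# base by rotary_dim and using the full-tensor element strides reproduces
# an explicit slice for every (idx, head) pair.
nnz, num_heads, head_dim, rotary_dim = 3, 2, 192, 128
q = np.arange(nnz * num_heads * head_dim, dtype=np.float32).reshape(nnz, num_heads, head_dim)

flat = q.ravel()
stride_n, stride_h = num_heads * head_dim, head_dim  # element strides
base = rotary_dim                                    # "pointer" offset

for idx in range(nnz):
    for head in range(num_heads):
        start = base + idx * stride_n + head * stride_h
        np.testing.assert_array_equal(
            flat[start:start + head_dim - rotary_dim],
            q[idx, head, rotary_dim:])
```

Because the last axis is contiguous, each tail slice is a single contiguous run, so the offset-base addressing never crosses into a neighboring head.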
1050-1118: LGTM: Dynamic dispatch supports arbitrary dimensions.

The refactored dispatch logic correctly computes:

- Thread block dimensions ensuring at least `bdx` threads in the x-dimension to cover `rope_dim` with vectorization
- At least 128 threads per block for occupancy
- A dynamic `no_rope_chunks` based on the `no_rope_dim / rope_dim` ratio

The launch configuration with the programmatic stream serialization attribute is properly constructed.
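The three computed quantities listed above can be sketched as host-side arithmetic. This is a hypothetical model of the dispatch math (parameter defaults are illustrative, not the actual dispatch code):

```python
def launch_params(rope_dim, no_rope_dim, vec_size=8, min_threads=128):
    # bdx lanes cover rope_dim with vec_size-wide vector accesses,
    # bdy is padded so each block has at least min_threads threads,
    # and the no-RoPE region is processed in rope_dim-sized chunks
    # (rounding up yields a trailing partial chunk when needed).
    bdx = (rope_dim + vec_size - 1) // vec_size
    bdy = max(1, min_threads // bdx)
    no_rope_chunks = (no_rope_dim + rope_dim - 1) // rope_dim
    return bdx, bdy, no_rope_chunks

print(launch_params(128, 64))  # → (16, 8, 1)
```

The ceiling division in `no_rope_chunks` is what creates the partial tail chunk that `scale_store_partial_chunk` must guard.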
1138-1195: LGTM: Consistent dispatch pattern for paged KV cache.

The dispatch logic mirrors `RopeQuantize` with appropriate adjustments for cache append operations. The `total_blocks_y` correctly includes V processing blocks for GQA/MHA.
I think the fundamental issue that leads to this perf gap is that the
Let me know your thoughts.
📌 Description
With the new changes we should be able to support arbitrary head dims using the `RopeQuantizeKernel`, and I have routed `BatchQKApplyRotaryPosIdsCosSinCache` to do so.

🔍 Related Issues
#2104
🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.). NOTE: There was a set of tests where I got an error which I know is related to my system. Unfortunately I do not manage this system and it does not have Docker, so trying to fix this is a bit difficult. Hopefully someone else can verify my tests or run CI. All other tests were passing, and all the failing tests had that error.
Reviewer Notes
Please let me know if there's a smarter way to get around this hack or if other tests should be updated. Also I think we should remove the older kernel but let me know if we should do otherwise. I also need to test perf.