use scalar for kv_scale in xqa #2033
Conversation
Signed-off-by: Qidi Sang <[email protected]>
Walkthrough

Converted kvCacheScale from a pointer/Tensor/TensorView to a plain scalar (float/double) across Python APIs, C++/CUDA bindings, kernel implementations, and tests; call sites and scale computations were updated to use direct scalar values instead of indexing into device memory or 1-element tensors.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python
    participant PyLayer as flashinfer/xqa.py
    participant Binding as csrc/flashinfer_xqa_binding.cu
    participant Wrapper as csrc/xqa/xqa_wrapper.cu
    participant Kernel as CUDA kernels (mha/mla)

    Note over Py,PyLayer: Old: passed 1-element tensor
    Py->>PyLayer: xqa(..., kv_scale=torch.tensor([s]))
    PyLayer->>Binding: call binding with TensorView kvCacheScale
    Binding->>Wrapper: forward TensorView
    Wrapper->>Kernel: launch(..., kvCacheScale=ptr)
    Kernel->>Kernel: scale = kvCacheScale[0]

    Note over Py,PyLayer: New: pass scalar float
    Py->>PyLayer: xqa(..., kv_scale=1.0)
    PyLayer->>Binding: call binding with double kvCacheScale
    Binding->>Wrapper: forward double kvCacheScale
    Wrapper->>Kernel: launch(..., kvCacheScale=float)
    Kernel->>Kernel: scale = kvCacheScale
```
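For orientation, here is a minimal C++/CUDA sketch of the by-value path shown in the diagram. The function names and signatures are simplified assumptions, not the actual FlashInfer code:

```cpp
// Hypothetical, heavily simplified signatures illustrating the scalar path.
__global__ void xqaKernel(float kvCacheScale /*, ... */) {
  // Old behavior: float scale = kvCacheScale[0];  // device read of a 1-element buffer
  float const scale = kvCacheScale;  // New behavior: plain by-value scalar
  (void)scale;                       // the scale would feed the attention math here
}

// The binding layer receives the Python float as a double and narrows it once.
void xqaWrapper(double kvCacheScale /*, ... */) {
  xqaKernel<<<1, 1>>>(static_cast<float>(kvCacheScale));
}
```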
Estimated code review effort: 3 (Moderate) | ~25 minutes
Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request streamlines the handling of the kv_scale parameter.
Code Review
This pull request refactors the kv_scale parameter to be a scalar value instead of a tensor across the XQA functionality. The changes are consistently applied to the CUDA kernels, C++ wrappers, and Python bindings. This is a good simplification that improves performance by avoiding a memory dereference in the CUDA kernels and makes the Python API cleaner. The code modifications are correct and the tests have been updated accordingly. This is a solid improvement to the codebase.
Actionable comments posted: 0
Nitpick comments (3)

csrc/xqa/mha.cu (1)

1304-1306: Refresh the kvCacheScale comment

Now that `kvCacheScale` is passed by value, the existing note about it being a "Device memory scalar" is misleading. Please update (and mirror any similar spots) so future readers don't assume they still need to manage a device allocation here.

csrc/xqa/mla_sm120.cu (1)

398-400: Tune the kvCacheScale comment

Same minor nit as in the MHA path: this argument is no longer a device pointer, so the comment should be adjusted accordingly (and any other duplicates cleaned up).

csrc/xqa/mha.h (1)

117-119: Sync comment with the new calling convention

This prototype (and the other overloads below) still advertises `kvCacheScale` as a "Device memory scalar," even though it's now passed by value. Updating the wording here will keep the header consistent with the implementation.
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

Files selected for processing (9)

- csrc/flashinfer_xqa_binding.cu (2 hunks)
- csrc/xqa/mha.cu (6 hunks)
- csrc/xqa/mha.h (5 hunks)
- csrc/xqa/mha_sm90.cu (7 hunks)
- csrc/xqa/mla_sm120.cu (6 hunks)
- csrc/xqa/xqa_wrapper.cu (4 hunks)
- flashinfer/decode.py (1 hunk)
- flashinfer/xqa.py (8 hunks)
- tests/attention/test_xqa.py (4 hunks)

Additional context used

Code graph analysis (1)

csrc/xqa/mha.h (1)

csrc/xqa/mla_sm120.cu (2)

- launchMLA (1651-1761)
- launchMLA (1651-1662)

Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
/bot run

[SUCCESS] Pipeline #37853197: 13/17 passed
bkryu
left a comment
Changes LGTM to me; @yzh119 can you take a quick look to give approval?
csrc/xqa/mha.cu
Outdated
| uint32_t const batchSize, | ||
| float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V | ||
| // cache. Used only for int8/fp8 KV cache. | ||
| float kvCacheScale, // Device memory scalar. Same scale for K and V |
The comment could be updated, right? It's no longer device memory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
tests/attention/test_xqa.py
Outdated
| seq_len_list.fill_(seq_len) | ||
|
|
||
| kv_cache_scale = torch.ones(1, dtype=torch.float32, device="cuda") | ||
| kv_cache_scale = 1.0 |
Could you also test the case where the value isn't 1.0?
Done
Signed-off-by: Qidi Sang <[email protected]>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.

Outside diff range comments (1)
csrc/xqa/mha_sm90.cu (1)

1315-1334: Defensive check for zero kvCacheScale.

rcpKScale/rcpVScale compute 1/kvCacheScale. If a caller passes 0, this will poison new-token writes. Add a debug assert or early return.

```diff
- float const rcpKScale = 1.F / kvCacheScale;
+ assert(kvCacheScale != 0.f && "kvCacheScale must be non-zero");
+ float const rcpKScale = 1.F / kvCacheScale;
  ...
- float const rcpVScale = 1.F / kvCacheScale;
+ assert(kvCacheScale != 0.f && "kvCacheScale must be non-zero");
+ float const rcpVScale = 1.F / kvCacheScale;
```

Also applies to: 1373-1389
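If device-side asserts are compiled out in release builds, a host-side guard before launch is another option. A minimal sketch, assuming the check lives in the wrapper layer (the helper name is made up):

```cpp
#include <stdexcept>

// Hypothetical host-side guard, invoked before launching the XQA kernel.
inline void checkKvCacheScale(float kvCacheScale) {
  if (kvCacheScale == 0.f) {
    throw std::invalid_argument("kvCacheScale must be non-zero");
  }
}
```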
Duplicate comments (3)

tests/attention/test_xqa.py (1)

183-185: Good: parametrize kv_scale and q_scale (also covers non-1.0).

This directly addresses the earlier ask to test values not equal to 1.0. Looks solid.

Also applies to: 451-453

csrc/xqa/mha.cu (2)

2386-2415: Host kernel wrappers: scalar kvCacheScale end-to-end.

Signatures and launches look aligned. Past note about "device memory" wording for kv scale is resolved here.

Also applies to: 2443-2476, 2547-2551, 2563-2576, 2618-2621

2443-2551: No additional action; see SM90 comment for repo-wide verification.

Avoiding duplicate scripts; rely on earlier verification to catch stale pointer/indexing.

Also applies to: 2563-2622
Nitpick comments (3)

tests/attention/test_xqa.py (3)

28-30: Guard CUDA device queries at import time.

Accessing CUDA props at module import can raise on non-CUDA nodes before pytest marks are evaluated. Move behind a check or into tests.

```diff
-props = torch.cuda.get_device_properties(0)
-sm_count = props.multi_processor_count
+if not torch.cuda.is_available():
+    pytest.skip("CUDA is not available", allow_module_level=True)
+props = torch.cuda.get_device_properties(0)
+sm_count = props.multi_processor_count
```

183-185: Optional: add an amplified scale (e.g., 2.0) to stress saturation paths.

Broader coverage can catch fp8 saturation/regression when scaling up.

```diff
-@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
+@pytest.mark.parametrize("kv_scale", [1.0, 0.5, 2.0])
-@pytest.mark.parametrize("q_scale", [1.0, 0.5])
+@pytest.mark.parametrize("q_scale", [1.0, 0.5, 2.0])
```

Also applies to: 451-453

634-636: Replace hardcoded 576 with the valid_elems_per_head_qk variable for maintainability.

The variable `valid_elems_per_head_qk` is in scope at line 634 and equals 576. Using the variable instead of the magic number prevents silent drift if head dimensions change and aligns with the reference function's use of the `valid_elems_per_head` parameter.

```diff
-    q_scale=q_scale * math.sqrt(576),
+    q_scale=q_scale * math.sqrt(valid_elems_per_head_qk),
```
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

Files selected for processing (5)

- csrc/xqa/mha.cu (6 hunks)
- csrc/xqa/mha.h (6 hunks)
- csrc/xqa/mha_sm90.cu (7 hunks)
- csrc/xqa/mla_sm120.cu (6 hunks)
- tests/attention/test_xqa.py (8 hunks)

Files skipped from review as they are similar to previous changes (1)

- csrc/xqa/mha.h

Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Additional comments (14)

tests/attention/test_xqa.py (1)

201-203: Skip non-FP8 when kv_scale ≠ 1 is correct.

Kernel applies kv scale only for quantized KV; skipping this matrix of cases is appropriate.

csrc/xqa/mha_sm90.cu (4)

606-637: Kernel takes kvCacheScale as a scalar now: good.

Signature update removes device reads and simplifies launches.

775-777: qkScale correctly gates kvCacheScale to quantized KV.

Using kvCacheScale only when cache is int8/fp8 matches the intended behavior (a generic sketch of this gating pattern appears after these comments).

964-965: xvoScale update is consistent with qkScale gating.

Scalar kvCacheScale applied only for quantized KV on the XV path.

2933-2934: Migration to scalar kvCacheScale verified successfully across all call sites.

All kernel declarations (lines 2933, 3045), kernel invocations (lines 3019, 3098), and internal kernel logic (arithmetic operations at lines 775, 964, 1318, 1377) consistently use `float kvCacheScale` as a scalar. Python bindings pass floats. No pointer-style usages, dereferences, or indexing remain.

csrc/xqa/mha.cu (3)
1270-1307: Kernel impl now takes scalar kvCacheScale: good cleanup.

Removes unnecessary device indirection; fits with the gating on quantized KV.

1505-1507: qkScale update matches SM90 logic.

Gated application of kvCacheScale is correct here too.

2158-2171: voScale includes kvCacheScale only for quantized KV.

Consistent and correct; preserves FP16/BF16 behavior.
csrc/xqa/mla_sm120.cu (6)

398-398: LGTM! Simplified parameter from reference to scalar.

The change from `float const&` to `float` is appropriate for a scalar value and eliminates unnecessary indirection. This aligns with the PR's objective to simplify the kvCacheScale interface across the codebase.

451-451: LGTM! Updated to direct scalar access.

The change from `args.kvCacheScale[0]` to `args.kvCacheScale` correctly reflects the parameter type change and simplifies the code.

1230-1230: LGTM! Scalar access updated correctly.

Both uses of `kvCacheScale` in the ternary expression have been correctly updated to direct scalar access.

1555-1555: LGTM! Kernel parameter simplified to scalar.

The change from `float const* __restrict__` to `float` simplifies the kernel interface and is more efficient for GPU kernels when passing a single scalar value.

1658-1658: LGTM! Launch function signature updated consistently.

The parameter type change aligns with the kernel signature update and maintains consistency throughout the call chain.

1779-1779: LGTM! Signature update maintains consistency.

The parameter type change is consistent with the kernel signature and the other launch function (`launchMLA`), ensuring a unified interface across all entry points.
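The "gating" mentioned in the SM90 and MHA comments above means kvCacheScale only enters the scale product when the KV cache is quantized. A minimal sketch of that pattern, with assumed names and without the real kernels' additional scale factors:

```cpp
// Illustrative only: the function and flag names are assumptions, not the
// actual FlashInfer kernel code.
__device__ inline float applyKvScale(float baseScale, float kvCacheScale,
                                      bool kvCacheIsQuantized) {
  // kvCacheScale participates only for int8/fp8 KV caches; otherwise the
  // base scale is used unchanged.
  return baseScale * (kvCacheIsQuantized ? kvCacheScale : 1.f);
}
```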
yzh119
left a comment
LGTM, we can create a followup PR to also support device-side scale as well (for both xqa and trtllm-gen backend @yyihuang ).
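For context on the suggested follow-up, one conceivable shape for supporting both host and device scales is sketched below. This is purely speculative, not what any follow-up PR implements, and the struct and field names are invented:

```cpp
// Hypothetical argument type carrying either a host scalar or an optional
// pointer to a 1-element device buffer.
struct KvScaleArg {
  float hostValue;         // used when devicePtr is null
  float const* devicePtr;  // optional device-side scale
};

// Device-side resolution inside a kernel.
__device__ inline float resolveKvScale(KvScaleArg s) {
  return s.devicePtr != nullptr ? *s.devicePtr : s.hostValue;
}
```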
…tion (#2084)

Description:
- change `bmm1_scale` and `bmm2_scale` to `Union[float, torch.Tensor]`; note that when using a tensor, log2e must already be applied
- remove the `bmm1_scale_log2_tensor` and `bmm2_scale_tensor` arguments in `xqa_batch_decode_with_kv_cache_mla`
- update trtllm-gen FMHA kernels

TODO: do the same refactor for the xqa kernels. Support for device-side scales was removed in #2033.

Summary by CodeRabbit:
- New Features: attention scale parameters now accept either floats or 1-element tensors across prefill, decode, and runtime; tensor scales are validated and applied on-device, and pointer-backed scale paths are supported.
- Chores: updated FMHA artifact path and checksum constants; added a public utility import and removed an obsolete inline comment.
- Tests: updated tests to exercise device/tensor-or-scalar scale flows, removed legacy per-tensor call-site args, and added device-scale parametrization for several test variants.

Signed-off-by: Siyuan Fu <[email protected]>