
Conversation

@qsang-nv (Collaborator) commented Nov 4, 2025

πŸ“Œ Description

πŸ” Related Issues

πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

βœ… Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

πŸ§ͺ Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Breaking Changes

    • Public xqa/xqa_mla entry points now accept kv_scale as a plain float (default 1.0) instead of a 1-element tensor. Update call sites accordingly.
  • Documentation

    • Docstrings updated to reflect kv_scale as float.
  • Tests

    • Tests updated to pass scalar kv_scale, with added parameterization and conditional skip for FP8 kv-cache scenarios.

Signed-off-by: Qidi Sang <[email protected]>
@coderabbitai bot (Contributor) commented Nov 4, 2025

Walkthrough

Converted kvCacheScale from a pointer/Tensor/TensorView to a plain scalar (float/double) across Python APIs, C++/CUDA bindings, kernel implementations, and tests; call sites and scale computations were updated to use direct scalar values instead of indexing into device memory or 1-element tensors.

Changes

  • CUDA XQA Binding Layer (csrc/flashinfer_xqa_binding.cu): updated the public wrapper signatures xqa_wrapper_mla and xqa_wrapper to accept kvCacheScale as double instead of TensorView.
  • XQA Wrapper (csrc/xqa/xqa_wrapper.cu): changed xqa_wrapper_mla and xqa_wrapper to take double kvCacheScale and updated downstream launch calls to forward the scalar value.
  • MHA Kernel Implementation (csrc/xqa/mha.cu, csrc/xqa/mha.h, csrc/xqa/mha_sm90.cu): replaced float const* __restrict__ kvCacheScale with float kvCacheScale in kernel and launch signatures, replaced kvCacheScale[0] usages with kvCacheScale, and updated the qk/vo/rcp scale calculations.
  • MLA Kernel Implementation (csrc/xqa/mla_sm120.cu): converted KernelArgs and kernel_mha / launchMLA / launchMLAFlashInfer to store and pass float kvCacheScale by value and updated usages from kvCacheScale[0] to scalar access.
  • Python Interface (flashinfer/xqa.py): changed the public functions xqa, _fake_xqa, xqa_mla, and _fake_xqa_mla to accept kv_scale: float (default 1.0) instead of Optional[torch.Tensor]; removed the None-handling that created a tensor.
  • Python Caller & Tests (flashinfer/decode.py, tests/attention/test_xqa.py): updated call sites to pass a plain float for kv_scale/kv_cache_scale (removing the 1-element tensor wrappers); updated tests to parameterize and pass scalar values, and added a conditional skip for unsupported FP8 KV-cache cases. A migration sketch follows this list.
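A minimal migration sketch for call sites affected by the Python Interface and Python Caller rows above. The helper name as_kv_scale is hypothetical and not part of flashinfer; only the "1-element tensor becomes a plain float" convention is taken from this PR.

```python
import torch
from typing import Union

def as_kv_scale(value: Union[float, torch.Tensor]) -> float:
    """Normalize a legacy 1-element kv_scale tensor to the plain float that the
    updated xqa/xqa_mla entry points expect (hypothetical migration helper)."""
    if isinstance(value, torch.Tensor):
        if value.numel() != 1:
            raise ValueError("kv_scale tensor must contain exactly one element")
        return float(value.item())
    return float(value)

# Old call sites built torch.tensor([0.5], device="cuda"); new ones pass 0.5 directly.
assert as_kv_scale(torch.tensor([0.5])) == 0.5
assert as_kv_scale(0.5) == 0.5
```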

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python
    participant PyLayer as flashinfer/xqa.py
    participant Binding as csrc/flashinfer_xqa_binding.cu
    participant Wrapper as csrc/xqa/xqa_wrapper.cu
    participant Kernel as CUDA kernels (mha/mla)

    Note over Py,PyLayer: Old: passed 1-element tensor
    Py->>PyLayer: xqa(..., kv_scale=torch.tensor([s]))
    PyLayer->>Binding: call binding with TensorView kvCacheScale
    Binding->>Wrapper: forward TensorView
    Wrapper->>Kernel: launch(..., kvCacheScale=ptr)
    Kernel->>Kernel: scale = kvCacheScale[0]

    Note over Py,PyLayer: New: pass scalar float
    Py->>PyLayer: xqa(..., kv_scale=1.0)
    PyLayer->>Binding: call binding with double kvCacheScale
    Binding->>Wrapper: forward double kvCacheScale
    Wrapper->>Kernel: launch(..., kvCacheScale=float)
    Kernel->>Kernel: scale = kvCacheScale

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Areas to check:
    • All kernel/launch signatures and call sites for consistent type changes (especially across sm90/sm120 variants)
    • Correct propagation from Python float β†’ C++ double/float without unintended conversions or precision issues
    • Tests and decode caller updates for matching runtime behavior when FP8/int8 KV cache is enabled

Suggested reviewers

  • djmmoss
  • yongwww
  • yzh119
  • cyx-6
  • nvmbreughe

Poem

🐰
Hop by hop I checked the scale,
No more tensors on the trail,
One small float, so light and neat,
Kernels hum and tests repeat,
A tidy hop β€” the change is complete.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Description check: ⚠️ Warning. The PR description is just the repository template with all sections empty or unchecked; no actual description, motivation, or changes are provided. Resolution: fill in the Description section explaining what was changed and why, and complete the checklist items to confirm that pre-commit checks and tests were performed.
βœ… Passed checks (1 passed)
  • Title check: βœ… Passed. The title accurately describes the main change: converting the kv_scale parameter from tensor/pointer types to a scalar across the xqa codebase.


@gemini-code-assist bot (Contributor)

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request streamlines the handling of the kv_scale parameter within the XQA (eXtended Query Attention) implementation. By transitioning kv_scale from a single-element tensor or pointer to a direct scalar, the change simplifies the API and internal logic across both the C++ CUDA kernels and their Python interfaces. This refinement aims to improve code readability and reduce any overhead associated with treating a singular scaling factor as a more complex data structure.

Highlights

  • API Simplification: The kv_scale parameter, used for KV cache quantization, has been refactored from a single-element tensor or pointer to a direct scalar value (float/double) across the C++ CUDA kernels and Python bindings.
  • Code Clarity and Efficiency: This change removes the need for tensor creation and pointer dereferencing for kv_scale, leading to cleaner code and potentially minor performance improvements by treating the scaling factor as a primitive type.
  • Python Interface Update: Python functions like xqa and xqa_mla now directly accept a float for kv_scale, and the logic for handling None or creating a single-element tensor has been removed.
  • Test Adjustments: Corresponding test cases have been updated to pass kv_scale as a scalar float, aligning with the new API.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request refactors the kv_scale parameter to be a scalar value instead of a tensor across the XQA functionality. The changes are consistently applied to the CUDA kernels, C++ wrappers, and Python bindings. This is a good simplification that improves performance by avoiding a memory dereference in the CUDA kernels and makes the Python API cleaner. The code modifications are correct and the tests have been updated accordingly. This is a solid improvement to the codebase.

@coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
csrc/xqa/mha.cu (1)

1304-1306: Refresh the kvCacheScale comment

Now that kvCacheScale is passed by value, the existing note about it being a β€œDevice memory scalar” is misleading. Please update (and mirror any similar spots) so future readers don’t assume they still need to manage a device allocation here.

csrc/xqa/mla_sm120.cu (1)

398-400: Tune the kvCacheScale comment

Same minor nit as in the MHA path: this argument is no longer a device pointer, so the comment should be adjusted accordingly (and any other duplicates cleaned up).

csrc/xqa/mha.h (1)

117-119: Sync comment with the new calling convention

This prototype (and the other overloads below) still advertises kvCacheScale as a β€œDevice memory scalar,” even though it’s now passed by value. Updating the wording here will keep the header consistent with the implementation.

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 2d68a6b and b94cb61.

πŸ“’ Files selected for processing (9)
  • csrc/flashinfer_xqa_binding.cu (2 hunks)
  • csrc/xqa/mha.cu (6 hunks)
  • csrc/xqa/mha.h (5 hunks)
  • csrc/xqa/mha_sm90.cu (7 hunks)
  • csrc/xqa/mla_sm120.cu (6 hunks)
  • csrc/xqa/xqa_wrapper.cu (4 hunks)
  • flashinfer/decode.py (1 hunks)
  • flashinfer/xqa.py (8 hunks)
  • tests/attention/test_xqa.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/xqa/mha.h (1)
csrc/xqa/mla_sm120.cu (2)
  • launchMLA (1651-1761)
  • launchMLA (1651-1662)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@bkryu (Collaborator) commented Nov 4, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !108 has been created, and the CI pipeline #37853197 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[SUCCESS] Pipeline #37853197: 13/17 passed

@bkryu (Collaborator) left a comment

Changes LGTM to me; @yzh119 can you take a quick look to give approval?

csrc/xqa/mha.cu Outdated
uint32_t const batchSize,
float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V
// cache. Used only for int8/fp8 KV cache.
float kvCacheScale, // Device memory scalar. Same scale for K and V
A collaborator commented:

The comment could be updated, right? It's no longer device memory.

@qsang-nv (Author) replied:

Done

seq_len_list.fill_(seq_len)

kv_cache_scale = torch.ones(1, dtype=torch.float32, device="cuda")
kv_cache_scale = 1.0
A collaborator commented:

Could you also test the case where the value isn’t 1.0?

@qsang-nv (Author) replied:

Done

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/xqa/mha_sm90.cu (1)

1315-1334: Defensive check for zero kvCacheScale.

rcpKScale/rcpVScale compute 1/kvCacheScale. If a caller passes 0, this will poison new‑token writes. Add a debug assert or early return.

- float const rcpKScale = 1.F / kvCacheScale;
+ assert(kvCacheScale != 0.f && "kvCacheScale must be non-zero");
+ float const rcpKScale = 1.F / kvCacheScale;
...
- float const rcpVScale = 1.F / kvCacheScale;
+ assert(kvCacheScale != 0.f && "kvCacheScale must be non-zero");
+ float const rcpVScale = 1.F / kvCacheScale;

Also applies to: 1373-1389

♻️ Duplicate comments (3)
tests/attention/test_xqa.py (1)

183-185: Good: parametrize kv_scale and q_scale (also covers non‑1.0).

This directly addresses the earlier ask to test values not equal to 1.0. Looks solid.

Also applies to: 451-453

csrc/xqa/mha.cu (2)

2386-2415: Host kernel wrappers: scalar kvCacheScale end‑to‑end.

Signatures and launches look aligned.

Past note about β€œdevice memory” wording for kv scale is resolved here.

Also applies to: 2443-2476, 2547-2551, 2563-2576, 2618-2621


2443-2551: No additional action β€” see SM90 comment for repo‑wide verification.

Avoiding duplicate scripts; rely on earlier verification to catch stale pointer/indexing.

Also applies to: 2563-2622

🧹 Nitpick comments (3)
tests/attention/test_xqa.py (3)

28-30: Guard CUDA device queries at import time.

Accessing CUDA props at module import can raise on non‑CUDA nodes before pytest marks are evaluated. Move behind a check or into tests.

- props = torch.cuda.get_device_properties(0)
- sm_count = props.multi_processor_count
+ if not torch.cuda.is_available():
+     pytest.skip("CUDA is not available", allow_module_level=True)
+ props = torch.cuda.get_device_properties(0)
+ sm_count = props.multi_processor_count

183-185: Optional: add an amplified scale (e.g., 2.0) to stress saturation paths.

Broader coverage can catch fp8 saturation/regression when scaling up.

-@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
+@pytest.mark.parametrize("kv_scale", [1.0, 0.5, 2.0])
-@pytest.mark.parametrize("q_scale", [1.0, 0.5])
+@pytest.mark.parametrize("q_scale", [1.0, 0.5, 2.0])

Also applies to: 451-453


634-636: Replace hardcoded 576 with valid_elems_per_head_qk variable for maintainability.

The variable valid_elems_per_head_qk is in scope at line 634 and equals 576. Using the variable instead of the magic number prevents silent drift if head dimensions change and aligns with the reference function's use of valid_elems_per_head parameter.

-    q_scale=q_scale * math.sqrt(576),
+    q_scale=q_scale * math.sqrt(valid_elems_per_head_qk),
πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between b94cb61 and 33db49e.

πŸ“’ Files selected for processing (5)
  • csrc/xqa/mha.cu (6 hunks)
  • csrc/xqa/mha.h (6 hunks)
  • csrc/xqa/mha_sm90.cu (7 hunks)
  • csrc/xqa/mla_sm120.cu (6 hunks)
  • tests/attention/test_xqa.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/xqa/mha.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (14)
tests/attention/test_xqa.py (1)

201-203: Skipping non-FP8 cases when kv_scale β‰  1 is correct.

The kernel applies the KV scale only for quantized (int8/fp8) KV caches, so skipping that slice of the parameter matrix is appropriate.
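A minimal sketch of the skip pattern under discussion; the parameter names (kv_cache_dtype, kv_scale) and the exact condition are assumptions based on this review note, not the literal test code.

```python
import pytest

@pytest.mark.parametrize("kv_scale", [1.0, 0.5])
@pytest.mark.parametrize("kv_cache_dtype", ["fp16", "fp8"])
def test_xqa_kv_scale(kv_cache_dtype, kv_scale):
    # kv_scale only takes effect for quantized (int8/fp8) KV caches, so the
    # non-quantized combinations with kv_scale != 1.0 are skipped.
    if kv_cache_dtype not in ("int8", "fp8") and kv_scale != 1.0:
        pytest.skip("kv_scale is only applied for quantized KV caches")
    ...  # build inputs and call xqa(..., kv_scale=kv_scale)
```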

csrc/xqa/mha_sm90.cu (4)

606-637: Kernel takes kvCacheScale as a scalar now β€” good.

Signature update removes device reads and simplifies launches.


775-777: qkScale correctly gates kvCacheScale to quantized KV.

Using kvCacheScale only when cache is int8/fp8 matches the intended behavior.


964-965: xvoScale update is consistent with qkScale gating.

Scalar kvCacheScale applied only for quantized KV on the XV path.


2933-2934: Migration to scalar kvCacheScale verified successfully across all call sites.

All kernel declarations (lines 2933, 3045), kernel invocations (lines 3019, 3098), and internal kernel logic (arithmetic operations at lines 775, 964, 1318, 1377) consistently use float kvCacheScale as a scalar. Python bindings pass floats. No pointer-style usages, dereferences, or indexing remain.

csrc/xqa/mha.cu (3)

1270-1307: Kernel impl now takes scalar kvCacheScale β€” good cleanup.

Removes unnecessary device indirection; fits with the gating on quantized KV.


1505-1507: qkScale update matches SM90 logic.

Gated application of kvCacheScale is correct here too.


2158-2171: voScale includes kvCacheScale only for quantized KV.

Consistent and correct; preserves FP16/BF16 behavior.
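The gating described in these qkScale/voScale notes boils down to a simple rule; the sketch below is a Python illustration of that rule, not the kernel code itself.

```python
def effective_kv_scale(kv_cache_dtype: str, kv_scale: float) -> float:
    """kvCacheScale participates in the qk/vo scale factors only when the KV
    cache is quantized; fp16/bf16 caches keep their original behavior."""
    return kv_scale if kv_cache_dtype in ("int8", "fp8") else 1.0

assert effective_kv_scale("fp8", 0.5) == 0.5   # quantized cache: scale applies
assert effective_kv_scale("bf16", 0.5) == 1.0  # non-quantized cache: unchanged
```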

csrc/xqa/mla_sm120.cu (6)

398-398: LGTM! Simplified parameter from reference to scalar.

The change from float const& to float is appropriate for a scalar value and eliminates unnecessary indirection. This aligns with the PR's objective to simplify the kvCacheScale interface across the codebase.


451-451: LGTM! Updated to direct scalar access.

The change from args.kvCacheScale[0] to args.kvCacheScale correctly reflects the parameter type change and simplifies the code.


1230-1230: LGTM! Scalar access updated correctly.

Both uses of kvCacheScale in the ternary expression have been correctly updated to direct scalar access.


1555-1555: LGTM! Kernel parameter simplified to scalar.

The change from float const* __restrict__ to float simplifies the kernel interface and is more efficient for GPU kernels when passing a single scalar value.


1658-1658: LGTM! Launch function signature updated consistently.

The parameter type change aligns with the kernel signature update and maintains consistency throughout the call chain.


1779-1779: LGTM! Signature update maintains consistency.

The parameter type change is consistent with the kernel signature and the other launch function (launchMLA), ensuring a unified interface across all entry points.

@yzh119 (Collaborator) left a comment

LGTM; we can create a follow-up PR to also support device-side scales (for both the xqa and trtllm-gen backends, @yyihuang).

@yzh119 enabled auto-merge (squash) November 5, 2025 04:42
@yzh119 merged commit 6d19a75 into flashinfer-ai:main Nov 5, 2025
4 checks passed
jiahanc pushed a commit that referenced this pull request Nov 18, 2025
…tion (#2084)


## πŸ“Œ Description

- Change `bmm1_scale` and `bmm2_scale` to `Union[float, torch.Tensor]`. Note
that when a tensor is passed, it must already have the log2e factor applied
(sketched below).
- **Remove `bmm1_scale_log2_tensor` and `bmm2_scale_tensor` from
`xqa_batch_decode_with_kv_cache_mla`.**
- Update the trtllm-gen FMHA kernels.

TODO: do the same refactor for the xqa kernels. Support for device-side
scales was removed in #2033.
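A minimal sketch of what that convention implies for a caller, assuming from the first bullet that only the tensor path needs the log2e factor pre-applied; the scale value and variable names here are illustrative, not taken from the API.

```python
import math
import torch

LOG2_E = math.log2(math.e)  # β‰ˆ 1.4427

# Scalar path: pass the raw scale directly.
bmm1_scale = 1.0 / math.sqrt(128)  # e.g. 1/sqrt(head_dim) for head_dim = 128

# Tensor path: the 1-element tensor must already include the log2e factor
# (move it to the GPU with device="cuda" in real use).
bmm1_scale_tensor = torch.tensor([bmm1_scale * LOG2_E], dtype=torch.float32)
```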

## πŸ” Related Issues


## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
* Attention scale parameters now accept either floats or 1-element
tensors across prefill, decode and runtime; tensor scales are validated
and applied on-device and pointer-backed scale paths are supported.

* **Chores**
* Updated FMHA artifact path and checksum constants; added a public
utility import and removed an obsolete inline comment.

* **Tests**
* Updated tests to exercise device/tensor-or-scalar scale flows, removed
legacy per-tensor call-site args, and added device-scale parametrization
for several test variants.

---------

Signed-off-by: Siyuan Fu <[email protected]>
qsang-nv pushed a commit to qsang-nv/flashinfer that referenced this pull request Nov 18, 2025
…tion (flashinfer-ai#2084)
