
Conversation

@qsang-nv
Collaborator

@qsang-nv qsang-nv commented Nov 21, 2025

📌 Description

WIP. Do not merge; this is to check whether it fixes the flaky xqa test.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests

    • Default test seed changed to improve reproducibility; tests now use batched K/V handling, batched reference comparisons, expanded sequence-length cases, device-based scaling tensors, seeded shuffling, and batch-level validation with adjusted tolerances.
    • Over-provisioned GPU runs now skip instead of failing.
  • Bug Fixes

    • More consistent attention scaling and more robust GPU attention validation across batched and device-based test paths.


Signed-off-by: Qidi Sang <[email protected]>
@coderabbitai
Contributor

coderabbitai bot commented Nov 21, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Tests and CUDA kernels changed: test_xqa now defaults to seed 0, uses a seeded torch.Generator, assembles batched K/V caches and compares batched references; q_scale/kv_scale passed as device tensors; MHA/MLA kernels now read single-element scale pointers via array indexing (ptr[0]); one comm test now skips when world_size > GPUs.
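
As a rough illustration of the seeded-shuffle change described above, here is a minimal sketch (variable names and shapes are assumptions, not the test's exact code):

import torch

batch_size, num_pages = 2, 8
page_list = torch.arange(batch_size * num_pages, device="cuda").reshape(batch_size, num_pages)

# Dedicated generator: the shuffle no longer depends on global RNG state.
gen = torch.Generator(device="cuda")
gen.manual_seed(42)

flat = page_list.flatten()
perm = torch.randperm(flat.numel(), generator=gen, device="cuda")
page_list = flat[perm].reshape(batch_size, num_pages)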

Changes

  • XQA test refactor (tests/attention/test_xqa.py):
    Default seed changed to 0; RNG moved to a seeded torch.Generator; random permutations use the generator; seq_len options updated; q_scale/kv_scale passed as CUDA tensors; batched K/V cache assembly and batched reference attention; added torch.cuda.synchronize() and semaphore .zero_() around launches; adjusted tolerances and pass-ratio checks.
  • MHA / MLA kernels, scale access (csrc/xqa/mha.cu, csrc/xqa/mha_sm90.cu, csrc/xqa/mla_sm120.cu):
    Read single-element scale pointers with array indexing (qScalePtr[0], kvScalePtr[0]) instead of pointer dereference (*qScalePtr, *kvScalePtr); semantics preserved; no public API signature changes.
  • Comm test behavior (tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py):
    Replace raising ValueError when requested world_size exceeds available GPUs with pytest.skip(...), turning a hard failure into a test skip.

Sequence Diagram(s)

sequenceDiagram
  participant Test as tests/test_xqa.py
  participant CUDA as CUDA runtime
  participant Kernel as xqa MHA kernel

  Note over Test: prepare seeded Generator\ncreate Q/K/V tensors on device\nassemble batched K/V caches
  Test->>CUDA: seed CPU & CUDA RNGs (seed=0)
  Test->>CUDA: generate randperm using Generator
  Test->>CUDA: semaphores.zero_()
  Test->>CUDA: torch.cuda.synchronize()
  Test->>Kernel: launch xqa/xqa_mla (device tensors, q/kv scales as device tensors -> ptr)
  Kernel->>Kernel: load scales via ptr[0]
  Kernel-->>CUDA: kernel completes
  Test->>CUDA: torch.cuda.synchronize()
  Test-->>Test: batched comparison of kernel output vs ref_attention

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Check safe use of ptr[0] (ensure non-null device memory and proper lifetime).
  • Verify device-tensor-backed q/kv scale pointers are passed correctly to kernels.
  • Review batched K/V assembly and batched reference for shape/stride/layout correctness.
  • Confirm added synchronizations and semaphore resets don't introduce ordering issues.

Possibly related PRs

Suggested reviewers

  • cyx-6
  • yzh119
  • nvmbreughe
  • wenscarl

Poem

"I hopped through buffers, tidy and neat,
seeded the RNG for a steady beat.
Scales read by index, no more a deref fuss,
semaphores cleared — kernels hum with us.
🐇✨ Small hops, reproducible joy."

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 20.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check: ❓ Inconclusive. The PR description is marked 'WIP. Do not merge' and lacks concrete details about what was changed, why it fixes the flakiness, or the technical implementation specifics required for review. Resolution: remove the WIP status and provide concrete details: describe the root cause of the flakiness, explain which changes fix it, clarify the test parameterization changes, and indicate when the PR is ready for review.
✅ Passed checks (1 passed)
  • Title check: ✅ Passed. The title 'fix flaky xqa test' directly relates to the main objective of the PR, which involves changes to xqa test files and kernel implementations to address flakiness.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ea51957 and 3915e37.

📒 Files selected for processing (1)
  • tests/attention/test_xqa.py (8 hunks)
🧰 Additional context used
🪛 Ruff (0.14.5)
tests/attention/test_xqa.py

  • 60-60: Unpacked variable batch_size is never used; prefix it with an underscore or another dummy variable pattern (RUF059)
  • 60-60: Unpacked variable nb_k_heads is never used; prefix it with an underscore or another dummy variable pattern (RUF059)
  • 60-60: Unpacked variable head_grp_size is never used; prefix it with an underscore or another dummy variable pattern (RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (10)
tests/attention/test_xqa.py (10)

11-11: LGTM: Seed default changed to 0.

The change from seed=42 to seed=0 is consistent with the goal of fixing flaky tests.


34-120: LGTM: Batched reference implementation.

The refactored reference attention function now correctly handles batched inputs. The docstring clearly documents the expected tensor shapes, and all operations (Q·K^T, softmax, attention sinks, sliding window) properly account for the batch dimension.
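
For readers skimming the thread, a compact sketch of what a batched reference attention at these shapes computes (names and shapes are assumptions based on the comment, not the test's code):

import torch

def ref_attention_sketch(q, k, v, q_scale, kv_scale):
    # q: [batch, nb_k_heads, head_grp_size, head_dim]
    # k, v: [batch, nb_k_heads, seq_len, head_dim]
    head_dim = q.shape[-1]
    qk_scale = q_scale * kv_scale / head_dim ** 0.5
    scores = torch.einsum("bhgd,bhsd->bhgs", q.float(), k.float()) * qk_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhgs,bhsd->bhgd", probs, v.float())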


131-146: LGTM: Expanded seq_len coverage with xfail for flaky case.

The expanded parameterization includes more edge cases (512, 514). Marking seq_len=514 with xfail(strict=False) is appropriate for tracking a known flaky case while the PR investigates the root cause.


276-278: LGTM: Deterministic page shuffling.

Using a seeded CUDA generator for torch.randperm ensures deterministic page index shuffling, which is critical for fixing flaky tests.


357-358: LGTM: Scales passed as CUDA tensors.

Wrapping q_scale and kv_scale in torch.tensor(..., device="cuda") aligns with the kernel changes that expect tensor-based scales.


366-407: LGTM: Batch reconstruction of paged K/V caches.

The logic correctly reconstructs contiguous K/V caches from paged memory for all batches. The implementation properly handles both NHD and HND layouts, and the reshaping operations are correct.
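
A simplified sketch of this kind of reconstruction for the NHD layout (helper name and argument layout are assumptions, not the test's exact code):

import torch

def gather_contiguous_kv(paged_cache, page_list, req, num_pages, tokens_per_page):
    # paged_cache: [total_pages, tokens_per_page, nb_k_heads, head_dim] (NHD layout)
    # page_list:   [batch_size, max_pages] of page indices into paged_cache
    pages = page_list[req, :num_pages]            # pages owned by this request, in order
    gathered = paged_cache[pages]                 # [num_pages, tokens_per_page, heads, dim]
    return gathered.reshape(num_pages * tokens_per_page, *gathered.shape[2:])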


409-465: LGTM: Batched validation with appropriate tolerances.

The batch-level validation approach is sound:

  • Correctly reshapes both reference and kernel outputs to matching shapes
  • Sets appropriate tolerances for FP8 and non-FP8 cases
  • Uses 99% pass ratio for element-wise validation, which is reasonable given numerical precision limitations
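
To make the pass-ratio criterion above concrete, a hedged sketch of such a check (helper and threshold names are illustrative):

import torch

def check_pass_ratio(out, ref, atol, rtol, min_ratio=0.99):
    diff_abs = (out - ref).abs()
    diff_rel = diff_abs / ref.abs().clamp_min(1e-6)
    # An element passes if it is within either the absolute or the relative tolerance.
    passed = (diff_abs <= atol) | (diff_rel <= rtol)
    ratio = passed.float().mean().item()
    assert ratio >= min_ratio, f"only {ratio:.2%} of elements within tolerance"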

486-486: LGTM: Consistent seed initialization.

The MLA test correctly uses set_random_seed(0), consistent with the main xqa test.


611-646: LGTM: Batch reconstruction in MLA test.

The batch reconstruction logic for the MLA test correctly assembles K/V caches from paged memory. The implementation properly handles the NHD layout (the only layout supported for MLA).


648-695: LGTM: Batched validation for MLA test.

The batch-level validation for the MLA test is correctly implemented:

  • Properly adjusts q_scale by sqrt(576) for MLA attention
  • Reshapes outputs to matching shapes
  • Uses appropriate tolerances (0.05/0.05)
  • Sets pass ratio to 95%, which is reasonable given MLA's increased complexity


@gemini-code-assist
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces several targeted changes to the test_xqa.py file, primarily focused on enhancing the determinism and synchronization of CUDA operations. The modifications aim to address and resolve observed flakiness in the xqa test by ensuring consistent random number generation, explicit device tensor handling, and proper GPU synchronization points, thereby improving the reliability of the test suite.

Highlights

  • CUDA Random Seed Initialization: Added torch.cuda.random.manual_seed_all(seed) to the set_random_seed function to ensure all CUDA random number generators are initialized with a consistent seed, improving test reproducibility.
  • Deterministic Page Index Shuffling: Modified the torch.randperm call for shuffling page indices to use a torch.Generator with a fixed seed (42), making the page index generation deterministic.
  • CUDA Synchronization and Semaphore Reset: Introduced torch.cuda.synchronize() and semaphores.zero_() before the xqa function call to ensure all previous CUDA operations are complete and semaphores are reset, preventing potential race conditions.
  • CUDA Tensor Conversion and Post-Call Synchronization: Converted q_scale and kv_scale arguments to xqa into explicit torch.tensor objects on the 'cuda' device and added another torch.cuda.synchronize() call after xqa to ensure all GPU computations are finished before subsequent checks.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a flaky test in test_xqa.py by improving determinism and synchronization. The changes include adding explicit random number generator seeding, using CUDA tensors for scale parameters, and adding torch.cuda.synchronize() calls to ensure kernel completion. My review identifies a redundant random seed call and suggests using a constant for the hardcoded seed to improve maintainability. Overall, the changes are in the right direction to improve test stability.

np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.random.manual_seed_all(seed)
Contributor


medium

The function set_random_seed already calls torch.cuda.manual_seed_all(seed) on line 14. The added line torch.cuda.random.manual_seed_all(seed) is a duplicate, as torch.cuda.random.manual_seed_all is an alias for torch.cuda.manual_seed_all. This redundant call should be removed.

flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel(), device="cuda")
generator = torch.Generator(device="cuda")
generator.manual_seed(42)
Contributor


medium

The seed 42 is hardcoded here. It's also used in the set_random_seed(42) call at the beginning of this test on line 177. To improve maintainability and avoid magic numbers, consider defining a constant for the seed (e.g., SEED = 42) at the module level and using it in both places.
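
A sketch of the suggested refactor (constant and helper names are illustrative, not the test's code):

import torch

SEED = 42  # module-level constant: one source of truth for the test seed

def set_random_seed(seed: int = SEED) -> None:
    torch.manual_seed(seed)  # seeds CPU and all CUDA RNGs on recent PyTorch

def make_shuffle_generator(seed: int = SEED) -> torch.Generator:
    gen = torch.Generator(device="cuda")
    gen.manual_seed(seed)
    return gen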

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
tests/attention/test_xqa.py (5)

11-18: CUDA RNG seeding looks redundant; consider simplifying for clarity

set_random_seed now calls three CUDA-specific seeders (torch.cuda.manual_seed, torch.cuda.manual_seed_all, and torch.cuda.random.manual_seed_all). They likely hit the same RNG state, so the extra call doesn’t buy more determinism and may just be confusing. Consider standardizing on a single canonical CUDA seeding pattern (e.g., torch.manual_seed + torch.cuda.manual_seed_all) and dropping the duplicate, unless you’ve verified a specific PyTorch-version quirk that requires this.


270-276: Using a dedicated CUDA Generator for randperm is a solid determinism fix

The switch to a local torch.Generator(device="cuda") with a fixed seed for randperm decouples this shuffle from global RNG state and should help with test flakiness.

You might also consider mirroring this pattern in test_xqa_mla’s page shuffling (currently still using bare torch.randperm on CUDA) to keep both XQA paths equally reproducible and isolated from other tests’ RNG usage.


331-343: Semaphore re‑zeroing and pre‑call synchronize are likely unnecessary

Here semaphores is just created via torch.zeros a few lines above, so calling semaphores.zero_() again doesn’t change observable state. Unless you expect to reuse this tensor across multiple xqa invocations in the future, you can probably drop the extra zeroing.

Similarly, the torch.cuda.synchronize() before the call is conservative; if all the setup ops live on the same stream as xqa, ordering is already enforced. It’s fine to keep for safety, but a brief comment explaining why it’s needed (e.g., guarding against non-default streams inside the extension) would help prevent future “cleanup” from reintroducing flakiness.


344-363: Device-tensor scales for xqa look fine; verify API expectations and consistency with xqa_mla

Passing q_scale/kv_scale as 0‑D CUDA tensors should be compatible if xqa now expects device scalars. Two things to double‑check:

  • Confirm that the underlying xqa binding indeed supports/expects torch.Tensor inputs here (not just Python floats), especially across all call sites.
  • Decide whether xqa_mla should follow the same convention; right now it still receives Python float scales, so if the C++ side expects tensors for both kernels, you may want to mirror this conversion there as well.

If the APIs are intentionally asymmetric, adding a short comment around this call noting the expected dtypes/devices for these scale parameters would reduce future confusion.


365-366: Post‑xqa CUDA synchronize is appropriate to avoid async read issues

Adding torch.cuda.synchronize() after xqa ensures all device work has completed before the test inspects output, which is important if the extension uses non‑default streams or otherwise doesn’t synchronize with PyTorch’s usual tensors-to-host transitions.

If you’ve seen similar flakiness in test_xqa_mla, consider bracketing the xqa_mla call with the same pattern for symmetry and robustness.
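
A generic sketch of the bracketing pattern being discussed; launch_kernel stands in for the xqa/xqa_mla call and is not the actual API:

import torch

def run_and_read(launch_kernel, semaphores, output):
    semaphores.zero_()         # reset inter-CTA state before the launch
    torch.cuda.synchronize()   # make sure all setup work has completed
    launch_kernel()            # enqueue the kernel (e.g. the xqa extension call)
    torch.cuda.synchronize()   # wait for the kernel before touching its output
    return output.cpu()        # now safe to inspect on the host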

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2628beb and 119487e.

📒 Files selected for processing (1)
  • tests/attention/test_xqa.py (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@yzh119
Collaborator

yzh119 commented Nov 21, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !155 has been created, and the CI pipeline #38955732 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38955732: 14/18 passed

np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.random.manual_seed_all(seed)
Collaborator


In recent PyTorch versions, torch.manual_seed(seed) covers the semantics of torch.cuda.random.manual_seed_all, so there is no need to set the GPU seed explicitly: https://docs.pytorch.org/docs/stable/notes/randomness.html#pytorch-random-number-generator
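
A quick check of that behavior (assumes a CUDA device is available):

import torch

torch.manual_seed(0)               # also seeds every CUDA device's RNG
a = torch.randn(4, device="cuda")

torch.manual_seed(0)
b = torch.randn(4, device="cuda")

assert torch.equal(a, b)           # identical draws without explicit CUDA seeding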

Signed-off-by: Qidi Sang <[email protected]>
@qsang-nv qsang-nv requested a review from cyx-6 as a code owner November 22, 2025 08:00
@yzh119
Collaborator

yzh119 commented Nov 22, 2025

/bot run

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 119487e and 7c33f61.

📒 Files selected for processing (4)
  • csrc/xqa/mha.cu (1 hunks)
  • csrc/xqa/mha_sm90.cu (1 hunks)
  • csrc/xqa/mla_sm120.cu (1 hunks)
  • tests/attention/test_xqa.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
csrc/xqa/mha.cu (1)

1313-1314: LGTM: Pointer dereference replaced with array indexing.

The change from *qScalePtr to qScalePtr[0] (and similarly for kvScalePtr) is semantically equivalent and generates identical code. The null checks and fallback logic are correctly preserved.

csrc/xqa/mha_sm90.cu (1)

643-644: LGTM: Consistent refactor across kernel files.

The change matches the pattern in mha.cu, maintaining consistency across the codebase. The transformation from pointer dereference to array indexing is correct and preserves all null-checking logic.

csrc/xqa/mla_sm120.cu (1)

1567-1568: LGTM – Pointer access changed to array indexing for consistency.

The change from dereferencing (*qScalePtr, *kvScalePtr) to array indexing (qScalePtr[0], kvScalePtr[0]) is semantically equivalent and aligns with similar updates in related files.

tests/attention/test_xqa.py (2)

11-11: LGTM – Default seed updated.

The default seed change from 42 to 0 is cosmetic since all test invocations explicitly pass the seed value.


364-365: Post-call synchronization added for xqa test.

The torch.cuda.synchronize() after the xqa call ensures kernel completion before validation. However, the xqa_mla test is missing this post-call synchronization as well (see earlier comment).

Ensure the MLA test also includes post-call synchronization:

# Add after the xqa_mla call (after line 579)
torch.cuda.synchronize()

Comment on lines +271 to +273
generator = torch.Generator(device="cuda")
generator.manual_seed(42)
indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

Verify determinism in the MLA test.

The xqa test now uses a seeded generator for torch.randperm (good for determinism), but the xqa_mla test at line 522 still uses the default CUDA RNG without an explicit generator. This inconsistency could lead to non-deterministic behavior in the MLA test.


🏁 Script executed:

#!/bin/bash
# Verify that xqa_mla test uses deterministic RNG for randperm
rg -n -A2 -B2 'def test_xqa_mla' tests/attention/test_xqa.py
rg -n 'torch\.randperm' tests/attention/test_xqa.py

Length of output: 322


Add seeded generator to test_xqa_mla for deterministic RNG.

The xqa test uses a seeded generator for torch.randperm (line 273), but the xqa_mla test at line 522 still calls torch.randperm without an explicit generator, falling back to the default CUDA RNG. Apply the same generator pattern to test_xqa_mla to ensure deterministic behavior across both tests.

🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 271-273 and at the xqa_mla call near
line 522, the CUDA RNG is not consistently seeded; create a seeded CUDA
generator (generator = torch.Generator(device="cuda");
generator.manual_seed(42)) and pass it into torch.randperm as
generator=generator (keeping device="cuda") in the xqa_mla test so both tests
use the same deterministic RNG source.

Comment on lines 340 to 342
torch.cuda.synchronize()
semaphores.zero_()

Contributor


⚠️ Potential issue | 🟠 Major

Critical: MLA test missing synchronization.

The xqa test now includes torch.cuda.synchronize() and semaphores.zero_() before the kernel call—critical additions for ensuring proper ordering and clean state. However, the xqa_mla test (starting line 565) does not include these synchronization calls. Given that this PR aims to fix flaky xqa tests, the missing synchronization in the MLA test is a significant oversight that could cause flakiness.

Apply similar synchronization to the MLA test:

# Add before line 565 (before xqa_mla call)
torch.cuda.synchronize()
semaphores.zero_()
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 340-342 and specifically for the
xqa_mla test starting at line 565, the MLA variant is missing the GPU
synchronization and semaphore reset that were added for the xqa test; before
calling xqa_mla at ~line 565 add a torch.cuda.synchronize() call followed by
semaphores.zero_() (using the same semaphores variable used elsewhere) to ensure
proper ordering and a clean semaphore state before launching the kernel.

Comment on lines +355 to +356
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
Contributor


⚠️ Potential issue | 🔴 Critical

Critical: MLA test not updated with tensor scales.

The xqa test now passes q_scale and kv_scale as CUDA tensors, aligning with the kernel changes (array indexing in mla_sm120.cu). However, the xqa_mla test at lines 575-576 still passes these as Python scalars. This inconsistency could cause runtime errors or incorrect behavior in the MLA path.

Update the xqa_mla test to use tensor scales:

# Update lines 575-576 in xqa_mla call
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 575 to 576, the xqa_mla test still
passes q_scale and kv_scale as Python scalars while the rest of the tests (and
kernel changes) expect CUDA tensors; update the xqa_mla call to wrap both scales
with torch.tensor(..., device="cuda") so q_scale and kv_scale are passed as CUDA
tensors (matching the change at lines 355-356 and preventing MLA path
runtime/type errors).

@flashinfer-bot
Collaborator

GitLab MR !155 has been updated with latest changes, and the CI pipeline #38985166 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38985166: 13/18 passed

Signed-off-by: Qidi Sang <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/attention/test_xqa.py (1)

34-120: Address Ruff warning for unused unpacked shape variables in ref_attention

At Line 60, batch_size, nb_k_heads, and head_grp_size are unpacked from q.shape but never used, which Ruff flags (RUF059). This is likely what’s breaking the Ruff job in CI.

You can keep the shape annotation while silencing the warning by prefixing with underscores:

-    batch_size, nb_k_heads, head_grp_size, _ = q.shape
+    _batch_size, _nb_k_heads, _head_grp_size, _ = q.shape

Alternatively, you can drop the unpack entirely if you don’t need it.

♻️ Duplicate comments (1)
tests/attention/test_xqa.py (1)

476-599: Align xqa_mla scale arguments with the tensor‑based API

In test_xqa_mla you still pass q_scale and kv_cache_scale as Python scalars:

xqa_mla(
    ...
    q_scale=q_scale,
    kv_scale=kv_cache_scale,
    ...
)

Whereas test_xqa now passes these as CUDA tensors:

q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),

Given the kernel and API changes elsewhere, this inconsistency is likely to exercise a different path (or fail at runtime) for MLA versus non‑MLA, and it echoes a prior review comment on this test. I’d strongly suggest making MLA match the tensor convention:

-    xqa_mla(
-        q_heads.to(torch.float8_e4m3fn),
-        cache_k_heads.to(torch.float8_e4m3fn),
-        cache_v_heads.to(torch.float8_e4m3fn),
-        page_list_arg,
-        seq_len_list,
-        output,
-        scratch_buf,
-        semaphores,
-        tokens_per_page,
-        q_scale=q_scale,
-        kv_scale=kv_cache_scale,
-        sm_count=sm_count,
-        enable_pdl=enable_pdl,
-    )
+    xqa_mla(
+        q_heads.to(torch.float8_e4m3fn),
+        cache_k_heads.to(torch.float8_e4m3fn),
+        cache_v_heads.to(torch.float8_e4m3fn),
+        page_list_arg,
+        seq_len_list,
+        output,
+        scratch_buf,
+        semaphores,
+        tokens_per_page,
+        q_scale=torch.tensor(q_scale, device="cuda"),
+        kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
+        sm_count=sm_count,
+        enable_pdl=enable_pdl,
+    )

This keeps the MLA test in lockstep with the main XQA test and the underlying kernel expectations.

🧹 Nitpick comments (2)
tests/attention/test_xqa.py (2)

11-17: Simplify set_random_seed to rely on torch.manual_seed only

You’re calling torch.manual_seed, torch.cuda.manual_seed, and torch.cuda.manual_seed_all for the same seed. On recent PyTorch versions torch.manual_seed(seed) already seeds CPU and all CUDA RNGs, so the extra CUDA seeding is redundant and makes the behavior slightly harder to reason about.

You can simplify to a single call:

-def set_random_seed(seed=0):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+def set_random_seed(seed: int = 0) -> None:
+    # Seeds CPU and all CUDA RNGs on recent PyTorch versions
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False

259-263: Good use of a dedicated CUDA Generator for page shuffling

Using a local torch.Generator(device="cuda") with an explicit seed for torch.randperm decouples the page shuffling from the global RNG state and makes this part of the test reproducible even if other code touches the global generator.

If you prefer, you could factor the 42 into a small module‑level constant (e.g., PAGE_SHUFFLE_SEED = 42) to avoid a magic number, but the current behavior is otherwise sound.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7c33f61 and 7504aa2.

📒 Files selected for processing (1)
  • tests/attention/test_xqa.py (8 hunks)
🧰 Additional context used
🪛 Ruff (0.14.5)
tests/attention/test_xqa.py

  • 60-60: Unpacked variable batch_size is never used; prefix it with an underscore or another dummy variable pattern (RUF059)
  • 60-60: Unpacked variable nb_k_heads is never used; prefix it with an underscore or another dummy variable pattern (RUF059)
  • 60-60: Unpacked variable head_grp_size is never used; prefix it with an underscore or another dummy variable pattern (RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_xqa.py (2)

330-455: Batched reference path and tolerance check in test_xqa look consistent

The new sequence reconstruction and batched ref_attention comparison for test_xqa look dimensionally and numerically consistent:

  • Reconstructing batch_k_cache / batch_v_cache from page_list_arg covers exactly max_seq_len tokens (num_pages * tokens_per_page == max_seq_len).
  • Reshaping q_heads/output to [batch_size, nb_k_heads, head_grp_size, dim] matches the documented ref_attention interface.
  • The tolerance logic (atol/rtol depending on FP8 KV and FP8 output) plus a pass_ratio threshold is a reasonable way to reduce per‑element flakiness while still catching systematic issues.

No functional issues stand out in this block.


601-685: MLA batched reconstruction and reference comparison look correct

The MLA‑specific reconstruction and batched ref_attention comparison largely mirror the non‑MLA path and look coherent:

  • batch_k_cache/batch_v_cache shapes match [batch_size, nb_k_heads, max_seq_len, valid_elems_per_head_qk], with ref_attention truncating V to valid_elems_per_head_v via valid_elems_per_v_head.
  • The q_scale=q_scale * math.sqrt(576) adjustment correctly produces qk_scale = q_scale * kv_cache_scale inside ref_attention for valid_elems_per_head_qk=576.
  • The final reshape of output to [batch_size, nb_k_heads, head_grp_size, valid_elems_per_head_v] and the 0.95 pass‑ratio threshold are consistent with the MLA tolerances you’ve chosen.

Overall this block looks sound from a correctness and determinism standpoint.

@yzh119
Collaborator

yzh119 commented Nov 24, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !155 has been updated with latest changes, and the CI pipeline #39054334 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
tests/attention/test_xqa.py (1)

453-680: MLA test structure is sound; double-check q_scale/kv_scale API consistency and scaling

The MLA test mirrors the non‑MLA path nicely — same page shuffling, zeroing, batched reconstruction, and batchwise tolerance check — and the dimension handling for Q/K (576) vs V (512) looks correct. Two points worth verifying:

  • q_scale / kv_scale API consistency with xqa_mla

    • In test_xqa, you now pass q_scale and kv_scale as 1‑element CUDA tensors into xqa.
    • In test_xqa_mla, you still pass them as Python floats into xqa_mla.
      If the underlying xqa_mla binding has been updated to expect device tensors (similar to xqa), this mismatch could cause runtime issues or subtle differences vs the reference. If not, it’s still a bit surprising for the two APIs to diverge. Consider either:
    • updating the MLA call to q_scale=torch.tensor(q_scale, device="cuda") and kv_scale=torch.tensor(kv_cache_scale, device="cuda"), or
    • documenting that xqa_mla intentionally still accepts scalars while xqa takes tensors.

    This concern overlaps with an earlier review note about MLA scale handling, so please treat this as a follow‑up sanity check rather than a new issue.

  • Reference scaling factor for MLA
    In the reference call you pass q_scale=q_scale * math.sqrt(576) while ref_attention computes
    qk_scale = q_scale * kv_scale / sqrt(valid_elems_per_head) with valid_elems_per_head=576.
    That reduces to qk_scale = q_scale * kv_scale, which matches a kernel that uses qScale * kvCacheScale without an internal 1/sqrt(d) factor. This is plausible, but it’s subtle; if you continue to see MLA‑specific discrepancies, this scaling is a good place to double‑check against the kernel’s definition.

Everything else in this test — reconstruction, shapes, tolerances (atol = rtol = 0.05, ≥95% pass ratio), and the use of valid_elems_per_v_head to slice V down to 512 — looks coherent and numerically reasonable for an FP8 MLA path.
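
A tiny numeric check of that scaling algebra (the values are arbitrary; only the cancellation matters):

import math

def ref_qk_scale(q_scale, kv_scale, head_dim):
    # mirrors the reference's qk_scale = q_scale * kv_scale / sqrt(head_dim)
    return q_scale * kv_scale / math.sqrt(head_dim)

q_scale, kv_scale = 0.7, 1.3
adjusted = q_scale * math.sqrt(576)  # the adjustment applied on the MLA path
assert abs(ref_qk_scale(adjusted, kv_scale, 576) - q_scale * kv_scale) < 1e-12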

🧹 Nitpick comments (1)
tests/attention/test_xqa.py (1)

34-120: Tighten ref_attention API docs and unused shape unpack

The batched reference implementation looks consistent (shapes, sliding window masking, optional attention sinks, and separate K/V head dims all line up), but two small cleanups would help:

  • Line 60: batch_size, nb_k_heads, head_grp_size, _ = q.shape — none of these locals are used later. Either drop the unpack entirely or rename to _batch_size, _nb_k_heads, _head_grp_size to satisfy linters like Ruff (RUF059).
  • Docstring says seq_len can be a scalar or [batch_size] tensor, but the implementation treats it as a single scalar (:seq_len slicing and scalar comparisons). If per-batch lengths aren’t actually needed here, consider updating the docstring to just describe the scalar case to avoid confusion.
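
If per-batch lengths were ever needed, a hypothetical masked variant could look like this (not part of the current test):

import torch

def length_mask(seq_lens, max_len):
    # seq_lens: [batch_size] int tensor; returns a [batch_size, max_len] bool mask
    positions = torch.arange(max_len, device=seq_lens.device)
    return positions[None, :] < seq_lens[:, None]

# Applied to scores of shape [batch, heads, grp, max_len] before the softmax:
# scores.masked_fill_(~length_mask(seq_lens, max_len)[:, None, None, :], float("-inf"))
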
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between da8975a and ea51957.

📒 Files selected for processing (1)
  • tests/attention/test_xqa.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_xqa.py (3)
flashinfer/fused_moe/utils.py (1)
  • _ (157-163)
csrc/xqa/mla_sm120.cu (3)
  • x (455-457)
  • x (459-461)
  • x (1032-1037)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
🪛 Ruff (0.14.5)
tests/attention/test_xqa.py

  • 60-60: Unpacked variable batch_size is never used; prefix it with an underscore or another dummy variable pattern (RUF059)
  • 60-60: Unpacked variable nb_k_heads is never used; prefix it with an underscore or another dummy variable pattern (RUF059)
  • 60-60: Unpacked variable head_grp_size is never used; prefix it with an underscore or another dummy variable pattern (RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_xqa.py (2)

11-17: Seed helper centralizes determinism settings appropriately

set_random_seed now cleanly centralizes PyTorch and NumPy seeding plus cuDNN determinism flags. Given both tests call this at the top, this is a good place to control reproducibility; no issues from a correctness standpoint.


123-451: xqa test refactor (seeding, page shuffling, batched KV reconstruction, and batch validation) looks correct

Overall this test path looks solid and should help reduce flakiness:

  • Seeding: set_random_seed(0) at the start plus fixed-shape initializations makes the stochastic parts reproducible across runs.
  • Page shuffling: using page_list_arg with a CUDA randperm on the flattened indices, then reshaping back, yields a valid 1‑to‑1 page permutation per batch. Combined with the subsequent zeroing logic, the mapping between logical sequence positions and pages is correct for both full and partially filled last pages.
  • Zeroing unused cache positions: the start_page, token_start_in_first_page, and pages_to_zero logic correctly zeros:
    • only the tail of the first “partial” page, and
    • all subsequent pages,
      for both NHD and HND layouts, ensuring padded tokens don’t leak into the reference.
  • Batched cache reconstruction: batch_k_cache / batch_v_cache reconstruction via pages = page_list_arg[req, :num_pages] and reshape to contiguous [num_pages * tokens_per_page, head_dim] matches how seq_len is later interpreted in ref_attention (only the first seq_len positions are actually consumed).
  • Q reshaping: q_heads.squeeze(1).reshape(batch_size, nb_k_heads, head_grp_size, valid_elems_per_head) is dimensionally consistent with nb_q_heads = nb_k_heads * head_grp_size and matches the reference’s expectations.
  • Scale and validation:
    • Passing q_scale / kv_cache_scale as 1‑element CUDA tensors into xqa is consistent with pointer‑based scale handling on the kernel side.
    • The batch-level comparison (diff_abs/diff_rel, combined tolerance mask, and ≥99% pass ratio) is a good single-shot criterion that should be robust against small FP8 / mixed-precision noise while still sensitive to real regressions.

I don’t see correctness issues in this refactored test path; it’s a substantial improvement in clarity and robustness over per‑element checks.
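
For reference, an illustrative sketch of the partial-page zeroing logic described above, assuming an NHD paged layout (names follow the comment, but the code is not the test's):

import torch

def zero_unused_tokens(paged_cache, pages, seq_len, tokens_per_page):
    # paged_cache: [total_pages, tokens_per_page, heads, dim] (NHD layout)
    # pages: page indices for one request, in logical order
    start_page = seq_len // tokens_per_page        # first page with unused slots
    token_start = seq_len % tokens_per_page        # first unused slot within that page
    if token_start > 0:
        paged_cache[pages[start_page], token_start:] = 0  # tail of the partial page
        start_page += 1
    paged_cache[pages[start_page:]] = 0            # all fully unused pages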

Signed-off-by: Qidi Sang <[email protected]>
@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #39054334: 15/18 passed

@yzh119
Collaborator

yzh119 commented Nov 24, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !155 has been updated with latest changes, and the CI pipeline #39089850 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #39089850: 15/18 passed

Collaborator

@yzh119 yzh119 left a comment


LGTM. Per discussion with @qsang-nv, we found UT errors when seq_len=514 on Spark; the current workaround is to report xfail, and this should be investigated later.


float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
Collaborator


I don't think these changes matter, but they wouldn't hurt either.

@yzh119 yzh119 merged commit efd8554 into flashinfer-ai:main Nov 25, 2025
4 checks passed