fix flaky xqa test #2126
Conversation
Signed-off-by: Qidi Sang <[email protected]>
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Tests and CUDA kernels changed. test_xqa now defaults to seed 0, uses a seeded torch.Generator, assembles batched K/V caches, and compares against batched references; q_scale/kv_scale are passed as device tensors; the MHA/MLA kernels now read single-element scale pointers via array indexing (e.g. `qScalePtr[0]` instead of `*qScalePtr`).
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as tests/test_xqa.py
    participant CUDA as CUDA runtime
    participant Kernel as xqa MHA kernel
    Note over Test: prepare seeded Generator<br/>create Q/K/V tensors on device<br/>assemble batched K/V caches
    Test->>CUDA: seed CPU & CUDA RNGs (seed=0)
    Test->>CUDA: generate randperm using Generator
    Test->>CUDA: semaphores.zero_()
    Test->>CUDA: torch.cuda.synchronize()
    Test->>Kernel: launch xqa/xqa_mla (device tensors, q/kv scales as device tensors -> ptr)
    Kernel->>Kernel: load scales via ptr[0]
    Kernel-->>CUDA: kernel completes
    Test->>CUDA: torch.cuda.synchronize()
    Test-->>Test: batched comparison of kernel output vs ref_attention
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces several targeted changes to the …
Code Review
This pull request aims to fix a flaky test in test_xqa.py by improving determinism and synchronization. The changes include adding explicit random number generator seeding, using CUDA tensors for scale parameters, and adding torch.cuda.synchronize() calls to ensure kernel completion. My review identifies a redundant random seed call and suggests using a constant for the hardcoded seed to improve maintainability. Overall, the changes are in the right direction to improve test stability.
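For orientation, here is a minimal sketch of the synchronization and scale-handling pattern described above; the tensor sizes, dtypes, and the commented-out `xqa(...)` call are illustrative assumptions, not an excerpt from the test.

```python
import torch

# Scales handed to the kernel as 0-D CUDA tensors instead of Python floats,
# matching the pointer-based scale reads (e.g. qScalePtr[0]) on the kernel side.
q_scale_t = torch.tensor(1.0, device="cuda")
kv_scale_t = torch.tensor(1.0, device="cuda")

# Semaphore buffer used by the kernel: reset to a known state before the launch
# (size and dtype here are assumed for illustration).
semaphores = torch.zeros(1024, dtype=torch.int32, device="cuda")
semaphores.zero_()

torch.cuda.synchronize()  # ensure all setup work on the device has finished

# xqa(..., q_scale=q_scale_t, kv_scale=kv_scale_t, ...)  # kernel launch (omitted)

torch.cuda.synchronize()  # wait for the kernel before validating its output
```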
tests/attention/test_xqa.py (outdated)

```python
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.random.manual_seed_all(seed)
```
```python
flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel(), device="cuda")
generator = torch.Generator(device="cuda")
generator.manual_seed(42)
```
Actionable comments posted: 0
🧹 Nitpick comments (5)
tests/attention/test_xqa.py (5)
11-18: CUDA RNG seeding looks redundant; consider simplifying for clarity
`set_random_seed` now calls three CUDA-specific seeders (`torch.cuda.manual_seed`, `torch.cuda.manual_seed_all`, and `torch.cuda.random.manual_seed_all`). They likely hit the same RNG state, so the extra call doesn't buy more determinism and may just be confusing. Consider standardizing on a single canonical CUDA seeding pattern (e.g., `torch.manual_seed` + `torch.cuda.manual_seed_all`) and dropping the duplicate, unless you've verified a specific PyTorch-version quirk that requires this.
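A minimal sketch of the simplified seeding pattern this comment suggests, assuming no PyTorch-version quirk requires the extra calls (the explicit `torch.cuda.manual_seed_all` is kept only as a belt-and-braces step for older versions):

```python
import numpy as np
import torch


def set_random_seed(seed: int = 0) -> None:
    torch.manual_seed(seed)           # seeds the CPU RNG and, on recent PyTorch, all CUDA RNGs
    torch.cuda.manual_seed_all(seed)  # explicit multi-GPU seeding, only needed on older PyTorch
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```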
270-276: Using a dedicated CUDA Generator for `randperm` is a solid determinism fix

The switch to a local `torch.Generator(device="cuda")` with a fixed seed for `randperm` decouples this shuffle from global RNG state and should help with test flakiness.

You might also consider mirroring this pattern in `test_xqa_mla`'s page shuffling (currently still using bare `torch.randperm` on CUDA) to keep both XQA paths equally reproducible and isolated from other tests' RNG usage.
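As a sketch, mirroring that pattern in `test_xqa_mla` could look roughly like this; the sizes, page-table layout, and the seed of 42 are assumptions chosen for illustration:

```python
import torch

batch_size, num_pages = 4, 8  # illustrative sizes, not the test's real parameters

# Logical page table: one row of page indices per request.
page_list_arg = torch.arange(
    batch_size * num_pages, dtype=torch.int32, device="cuda"
).reshape(batch_size, num_pages)

# Seeded, test-local CUDA generator so the shuffle is independent of global RNG state.
generator = torch.Generator(device="cuda")
generator.manual_seed(42)

flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
page_list_arg = flattened[indices].reshape(page_list_arg.shape)
```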
331-343: Semaphore re‑zeroing and pre‑call synchronize are likely unnecessary

Here `semaphores` is just created via `torch.zeros` a few lines above, so calling `semaphores.zero_()` again doesn't change observable state. Unless you expect to reuse this tensor across multiple `xqa` invocations in the future, you can probably drop the extra zeroing.

Similarly, the `torch.cuda.synchronize()` before the call is conservative; if all the setup ops live on the same stream as `xqa`, ordering is already enforced. It's fine to keep for safety, but a brief comment explaining why it's needed (e.g., guarding against non-default streams inside the extension) would help prevent future "cleanup" from reintroducing flakiness.
344-363: Device-tensor scales for `xqa` look fine; verify API expectations and consistency with `xqa_mla`

Passing `q_scale`/`kv_scale` as 0‑D CUDA tensors should be compatible if `xqa` now expects device scalars. Two things to double‑check:

- Confirm that the underlying `xqa` binding indeed supports/expects `torch.Tensor` inputs here (not just Python floats), especially across all call sites.
- Decide whether `xqa_mla` should follow the same convention; right now it still receives Python float scales, so if the C++ side expects tensors for both kernels, you may want to mirror this conversion there as well.

If the APIs are intentionally asymmetric, adding a short comment around this call noting the expected dtypes/devices for these scale parameters would reduce future confusion.
365-366: Post‑`xqa` CUDA synchronize is appropriate to avoid async read issues

Adding `torch.cuda.synchronize()` after `xqa` ensures all device work has completed before the test inspects `output`, which is important if the extension uses non‑default streams or otherwise doesn't synchronize with PyTorch's usual tensor-to-host transitions.

If you've seen similar flakiness in `test_xqa_mla`, consider bracketing the `xqa_mla` call with the same pattern for symmetry and robustness.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/attention/test_xqa.py (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
/bot run

[FAILED] Pipeline #38955732: 14/18 passed
tests/attention/test_xqa.py (outdated)

```python
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.random.manual_seed_all(seed)
```
In recent PyTorch versions, torch.manual_seed(seed) should cover the semantics of torch.cuda.random.manual_seed_all, so there is no need to set the GPU seed explicitly: https://docs.pytorch.org/docs/stable/notes/randomness.html#pytorch-random-number-generator
Signed-off-by: Qidi Sang <[email protected]>
/bot run
Actionable comments posted: 3
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- csrc/xqa/mha.cu (1 hunks)
- csrc/xqa/mha_sm90.cu (1 hunks)
- csrc/xqa/mla_sm120.cu (1 hunks)
- tests/attention/test_xqa.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
csrc/xqa/mha.cu (1)
1313-1314: LGTM: Pointer dereference replaced with array indexing.

The change from `*qScalePtr` to `qScalePtr[0]` (and similarly for `kvScalePtr`) is semantically equivalent and generates identical code. The null checks and fallback logic are correctly preserved.

csrc/xqa/mha_sm90.cu (1)

643-644: LGTM: Consistent refactor across kernel files.

The change matches the pattern in `mha.cu`, maintaining consistency across the codebase. The transformation from pointer dereference to array indexing is correct and preserves all null-checking logic.

csrc/xqa/mla_sm120.cu (1)

1567-1568: LGTM – Pointer access changed to array indexing for consistency.

The change from dereferencing (`*qScalePtr`, `*kvScalePtr`) to array indexing (`qScalePtr[0]`, `kvScalePtr[0]`) is semantically equivalent and aligns with similar updates in related files.

tests/attention/test_xqa.py (2)

11-11: LGTM – Default seed updated.

The default seed change from 42 to 0 is cosmetic since all test invocations explicitly pass the seed value.

364-365: Post-call synchronization added for xqa test.

The `torch.cuda.synchronize()` after the xqa call ensures kernel completion before validation. However, the xqa_mla test is missing this post-call synchronization as well (see earlier comment). Ensure the MLA test also includes post-call synchronization:

```python
# Add after the xqa_mla call (after line 579)
torch.cuda.synchronize()
```
```python
generator = torch.Generator(device="cuda")
generator.manual_seed(42)
indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
```
🧩 Analysis chain
Verify determinism in the MLA test.
The xqa test now uses a seeded generator for torch.randperm (good for determinism), but the xqa_mla test at line 522 still uses the default CUDA RNG without an explicit generator. This inconsistency could lead to non-deterministic behavior in the MLA test.
🏁 Script executed:
```bash
#!/bin/bash
# Verify that xqa_mla test uses deterministic RNG for randperm
rg -n -A2 -B2 'def test_xqa_mla' tests/attention/test_xqa.py
rg -n 'torch\.randperm' tests/attention/test_xqa.py
```

Length of output: 322
Add seeded generator to test_xqa_mla for deterministic RNG.
The xqa test uses a seeded generator for torch.randperm (line 273), but the xqa_mla test at line 522 still calls torch.randperm without an explicit generator, falling back to the default CUDA RNG. Apply the same generator pattern to test_xqa_mla to ensure deterministic behavior across both tests.
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 271-273 and at the xqa_mla call near
line 522, the CUDA RNG is not consistently seeded; create a seeded CUDA
generator (generator = torch.Generator(device="cuda");
generator.manual_seed(42)) and pass it into torch.randperm as
generator=generator (keeping device="cuda") in the xqa_mla test so both tests
use the same deterministic RNG source.
tests/attention/test_xqa.py (outdated)

```python
torch.cuda.synchronize()
semaphores.zero_()
```
Critical: MLA test missing synchronization.
The xqa test now includes torch.cuda.synchronize() and semaphores.zero_() before the kernel call—critical additions for ensuring proper ordering and clean state. However, the xqa_mla test (starting line 565) does not include these synchronization calls. Given that this PR aims to fix flaky xqa tests, the missing synchronization in the MLA test is a significant oversight that could cause flakiness.
Apply similar synchronization to the MLA test:
```python
# Add before line 565 (before xqa_mla call)
torch.cuda.synchronize()
semaphores.zero_()
```

🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 340-342 and specifically for the
xqa_mla test starting at line 565, the MLA variant is missing the GPU
synchronization and semaphore reset that were added for the xqa test; before
calling xqa_mla at ~line 565 add a torch.cuda.synchronize() call followed by
semaphores.zero_() (using the same semaphores variable used elsewhere) to ensure
proper ordering and a clean semaphore state before launching the kernel.
```python
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
```
Critical: MLA test not updated with tensor scales.
The xqa test now passes q_scale and kv_scale as CUDA tensors, aligning with the kernel changes (array indexing in mla_sm120.cu). However, the xqa_mla test at lines 575-576 still passes these as Python scalars. This inconsistency could cause runtime errors or incorrect behavior in the MLA path.
Update the xqa_mla test to use tensor scales:
```python
# Update lines 575-576 in xqa_mla call
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
```

🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 575 to 576, the xqa_mla test still
passes q_scale and kv_scale as Python scalars while the rest of the tests (and
kernel changes) expect CUDA tensors; update the xqa_mla call to wrap both scales
with torch.tensor(..., device="cuda") so q_scale and kv_scale are passed as CUDA
tensors (matching the change at lines 355-356 and preventing MLA path
runtime/type errors).
[FAILED] Pipeline #38985166: 13/18 passed
Signed-off-by: Qidi Sang <[email protected]>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_xqa.py (1)
34-120: Address Ruff warning for unused unpacked shape variables in `ref_attention`

At Line 60, `batch_size`, `nb_k_heads`, and `head_grp_size` are unpacked from `q.shape` but never used, which Ruff flags (RUF059). This is likely what's breaking the Ruff job in CI.

You can keep the shape annotation while silencing the warning by prefixing with underscores:

```diff
- batch_size, nb_k_heads, head_grp_size, _ = q.shape
+ _batch_size, _nb_k_heads, _head_grp_size, _ = q.shape
```

Alternatively, you can drop the unpack entirely if you don't need it.
♻️ Duplicate comments (1)
tests/attention/test_xqa.py (1)
476-599: Align `xqa_mla` scale arguments with the tensor‑based API

In `test_xqa_mla` you still pass `q_scale` and `kv_cache_scale` as Python scalars:

```python
xqa_mla(
    ...
    q_scale=q_scale,
    kv_scale=kv_cache_scale,
    ...
)
```

Whereas `test_xqa` now passes these as CUDA tensors:

```python
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
```

Given the kernel and API changes elsewhere, this inconsistency is likely to exercise a different path (or fail at runtime) for MLA versus non‑MLA, and it echoes a prior review comment on this test. I'd strongly suggest making MLA match the tensor convention:

```diff
 xqa_mla(
     q_heads.to(torch.float8_e4m3fn),
     cache_k_heads.to(torch.float8_e4m3fn),
     cache_v_heads.to(torch.float8_e4m3fn),
     page_list_arg,
     seq_len_list,
     output,
     scratch_buf,
     semaphores,
     tokens_per_page,
-    q_scale=q_scale,
-    kv_scale=kv_cache_scale,
+    q_scale=torch.tensor(q_scale, device="cuda"),
+    kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
     sm_count=sm_count,
     enable_pdl=enable_pdl,
 )
```

This keeps the MLA test in lockstep with the main XQA test and the underlying kernel expectations.
🧹 Nitpick comments (2)
tests/attention/test_xqa.py (2)
11-17: Simplify `set_random_seed` to rely on `torch.manual_seed` only

You're calling `torch.manual_seed`, `torch.cuda.manual_seed`, and `torch.cuda.manual_seed_all` for the same `seed`. On recent PyTorch versions `torch.manual_seed(seed)` already seeds CPU and all CUDA RNGs, so the extra CUDA seeding is redundant and makes the behavior slightly harder to reason about.

You can simplify to a single call:

```diff
-def set_random_seed(seed=0):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+def set_random_seed(seed: int = 0) -> None:
+    # Seeds CPU and all CUDA RNGs on recent PyTorch versions
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
```
259-263: Good use of a dedicated CUDA `Generator` for page shuffling

Using a local `torch.Generator(device="cuda")` with an explicit seed for `torch.randperm` decouples the page shuffling from the global RNG state and makes this part of the test reproducible even if other code touches the global generator.

If you prefer, you could factor the `42` into a small module‑level constant (e.g., `PAGE_SHUFFLE_SEED = 42`) to avoid a magic number, but the current behavior is otherwise sound.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/attention/test_xqa.py (8 hunks)
🧰 Additional context used
🪛 Ruff (0.14.5)
tests/attention/test_xqa.py
60-60: Unpacked variable batch_size is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
60-60: Unpacked variable nb_k_heads is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
60-60: Unpacked variable head_grp_size is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_xqa.py (2)
330-455: Batched reference path and tolerance check in `test_xqa` look consistent

The new sequence reconstruction and batched `ref_attention` comparison for `test_xqa` look dimensionally and numerically consistent:

- Reconstructing `batch_k_cache`/`batch_v_cache` from `page_list_arg` covers exactly `max_seq_len` tokens (`num_pages * tokens_per_page == max_seq_len`).
- Reshaping `q_heads`/`output` to `[batch_size, nb_k_heads, head_grp_size, dim]` matches the documented `ref_attention` interface.
- The tolerance logic (`atol`/`rtol` depending on FP8 KV and FP8 output) plus a `pass_ratio` threshold is a reasonable way to reduce per‑element flakiness while still catching systematic issues; a sketch of this style of check appears below.

No functional issues stand out in this block.
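As referenced above, here is a self-contained sketch of this style of batch-level tolerance check; the thresholds, shapes, and helper name are illustrative rather than the test's exact values:

```python
import torch


def pass_ratio_check(out: torch.Tensor, ref: torch.Tensor,
                     atol: float = 0.05, rtol: float = 0.05,
                     min_pass_ratio: float = 0.99) -> bool:
    """Accept the output if enough elements fall within a combined abs/rel tolerance."""
    diff_abs = (out.float() - ref.float()).abs()
    diff_rel = diff_abs / ref.float().abs().clamp_min(1e-6)
    ok = (diff_abs <= atol) | (diff_rel <= rtol)
    return ok.float().mean().item() >= min_pass_ratio


# Tiny usage example with synthetic data.
ref = torch.randn(2, 4, 8, 64)
out = ref + 0.01 * torch.randn_like(ref)
assert pass_ratio_check(out, ref)
```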
601-685: MLA batched reconstruction and reference comparison look correct

The MLA‑specific reconstruction and batched `ref_attention` comparison largely mirror the non‑MLA path and look coherent:

- `batch_k_cache`/`batch_v_cache` shapes match `[batch_size, nb_k_heads, max_seq_len, valid_elems_per_head_qk]`, with `ref_attention` truncating V to `valid_elems_per_head_v` via `valid_elems_per_v_head`.
- The `q_scale=q_scale * math.sqrt(576)` adjustment correctly produces `qk_scale = q_scale * kv_cache_scale` inside `ref_attention` for `valid_elems_per_head_qk=576`.
- The final reshape of `output` to `[batch_size, nb_k_heads, head_grp_size, valid_elems_per_head_v]` and the 0.95 pass‑ratio threshold are consistent with the MLA tolerances you've chosen.

Overall this block looks sound from a correctness and determinism standpoint.
/bot run
Signed-off-by: Qidi Sang <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (1)
tests/attention/test_xqa.py (1)
453-680: MLA test structure is sound; double-check q_scale/kv_scale API consistency and scaling

The MLA test mirrors the non‑MLA path nicely (same page shuffling, zeroing, batched reconstruction, and batchwise tolerance check), and the dimension handling for Q/K (576) vs V (512) looks correct. Two points worth verifying:

**q_scale / kv_scale API consistency with `xqa_mla`**

- In `test_xqa`, you now pass `q_scale` and `kv_scale` as 1‑element CUDA tensors into `xqa`.
- In `test_xqa_mla`, you still pass them as Python floats into `xqa_mla`.

If the underlying `xqa_mla` binding has been updated to expect device tensors (similar to `xqa`), this mismatch could cause runtime issues or subtle differences vs the reference. If not, it's still a bit surprising for the two APIs to diverge. Consider either:

- updating the MLA call to `q_scale=torch.tensor(q_scale, device="cuda")` and `kv_scale=torch.tensor(kv_cache_scale, device="cuda")`, or
- documenting that `xqa_mla` intentionally still accepts scalars while `xqa` takes tensors.

This concern overlaps with an earlier review note about MLA scale handling, so please treat this as a follow‑up sanity check rather than a new issue.

**Reference scaling factor for MLA**

In the reference call you pass `q_scale=q_scale * math.sqrt(576)` while `ref_attention` computes `qk_scale = q_scale * kv_scale / sqrt(valid_elems_per_head)` with `valid_elems_per_head=576`. That reduces to `qk_scale = q_scale * kv_scale`, which matches a kernel that uses `qScale * kvCacheScale` without an internal `1/sqrt(d)` factor. This is plausible, but it's subtle; if you continue to see MLA‑specific discrepancies, this scaling is a good place to double‑check against the kernel's definition (a small numeric check of the algebra follows below).

Everything else in this test looks coherent and numerically reasonable for an FP8 MLA path: the reconstruction, the shapes, the tolerances (`atol = rtol = 0.05`, ≥95% pass ratio), and the use of `valid_elems_per_v_head` to slice V down to 512.
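A quick numeric check of the scaling algebra discussed above; the concrete scale values are arbitrary, and only the head dimension 576 comes from the comment:

```python
import math

q_scale, kv_scale = 0.7, 0.9  # arbitrary illustrative values
d = 576                       # valid_elems_per_head_qk for the MLA path

# Reference path: qk_scale = q_scale * kv_scale / sqrt(d), with q_scale pre-multiplied by sqrt(d).
qk_scale_ref = (q_scale * math.sqrt(d)) * kv_scale / math.sqrt(d)

# Kernel-style scaling without an internal 1/sqrt(d) factor.
qk_scale_kernel = q_scale * kv_scale

assert math.isclose(qk_scale_ref, qk_scale_kernel)
```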
🧹 Nitpick comments (1)
tests/attention/test_xqa.py (1)
34-120: Tighten `ref_attention` API docs and unused shape unpack

The batched reference implementation looks consistent (shapes, sliding window masking, optional attention sinks, and separate K/V head dims all line up), but two small cleanups would help:

- Line 60: `batch_size, nb_k_heads, head_grp_size, _ = q.shape`: none of these locals are used later. Either drop the unpack entirely or rename to `_batch_size`, `_nb_k_heads`, `_head_grp_size` to satisfy linters like Ruff (RUF059).
- The docstring says `seq_len` can be a scalar or `[batch_size]` tensor, but the implementation treats it as a single scalar (`:seq_len` slicing and scalar comparisons). If per-batch lengths aren't actually needed here, consider updating the docstring to just describe the scalar case to avoid confusion.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/attention/test_xqa.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_xqa.py (3)

- flashinfer/fused_moe/utils.py (1): `_` (157-163)
- csrc/xqa/mla_sm120.cu (3): `x` (455-457), `x` (459-461), `x` (1032-1037)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1): `enable_pdl` (220-220)
🪛 Ruff (0.14.5)
tests/attention/test_xqa.py
60-60: Unpacked variable batch_size is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
60-60: Unpacked variable nb_k_heads is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
60-60: Unpacked variable head_grp_size is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_xqa.py (2)
11-17: Seed helper centralizes determinism settings appropriately
`set_random_seed` now cleanly centralizes PyTorch and NumPy seeding plus cuDNN determinism flags. Given both tests call this at the top, this is a good place to control reproducibility; no issues from a correctness standpoint.
123-451: xqa test refactor (seeding, page shuffling, batched KV reconstruction, and batch validation) looks correct

Overall this test path looks solid and should help reduce flakiness:

- Seeding: `set_random_seed(0)` at the start plus fixed-shape initializations makes the stochastic parts reproducible across runs.
- Page shuffling: using `page_list_arg` with a CUDA `randperm` on the flattened indices, then reshaping back, yields a valid 1‑to‑1 page permutation per batch. Combined with the subsequent zeroing logic, the mapping between logical sequence positions and pages is correct for both full and partially filled last pages.
- Zeroing unused cache positions: the `start_page`, `token_start_in_first_page`, and `pages_to_zero` logic correctly zeros only the tail of the first "partial" page and all subsequent pages, for both NHD and HND layouts, ensuring padded tokens don't leak into the reference.
- Batched cache reconstruction: `batch_k_cache`/`batch_v_cache` reconstruction via `pages = page_list_arg[req, :num_pages]` and reshape to a contiguous `[num_pages * tokens_per_page, head_dim]` matches how `seq_len` is later interpreted in `ref_attention` (only the first `seq_len` positions are actually consumed); a simplified sketch of this pattern is shown after this list.
- Q reshaping: `q_heads.squeeze(1).reshape(batch_size, nb_k_heads, head_grp_size, valid_elems_per_head)` is dimensionally consistent with `nb_q_heads = nb_k_heads * head_grp_size` and matches the reference's expectations.
- Scale and validation: passing `q_scale`/`kv_cache_scale` as 1‑element CUDA tensors into `xqa` is consistent with pointer‑based scale handling on the kernel side, and the batch-level comparison (`diff_abs`/`diff_rel`, combined tolerance mask, and ≥99% pass ratio) is a good single-shot criterion that should be robust against small FP8 / mixed-precision noise while still sensitive to real regressions.

I don't see correctness issues in this refactored test path; it's a substantial improvement in clarity and robustness over per‑element checks.
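As referenced in the list above, here is a simplified, self-contained sketch of the batched K-cache reconstruction pattern; all shapes and names are assumptions chosen for illustration, not the test's actual layout:

```python
import torch

batch_size, num_pages, tokens_per_page, nb_k_heads, head_dim = 2, 4, 16, 2, 32
max_seq_len = num_pages * tokens_per_page
total_pages = batch_size * num_pages

# Paged K cache: [total_pages, tokens_per_page, nb_k_heads, head_dim] (an NHD-style layout).
cache_k_heads = torch.randn(total_pages, tokens_per_page, nb_k_heads, head_dim)

# Page table mapping each request to its (shuffled) physical pages.
page_list_arg = torch.randperm(total_pages).reshape(batch_size, num_pages)

# Gather each request's pages and flatten them into one contiguous per-request sequence.
batch_k_cache = torch.empty(batch_size, nb_k_heads, max_seq_len, head_dim)
for req in range(batch_size):
    pages = cache_k_heads[page_list_arg[req]]               # [num_pages, tokens_per_page, H, D]
    seq = pages.reshape(max_seq_len, nb_k_heads, head_dim)  # contiguous token order
    batch_k_cache[req] = seq.permute(1, 0, 2)               # [H, max_seq_len, D]
```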
Signed-off-by: Qidi Sang <[email protected]>
[SUCCESS] Pipeline #39054334: 15/18 passed

/bot run

[FAILED] Pipeline #39089850: 15/18 passed
yzh119
left a comment
LGTM. Per discussion with @qsang-nv, we found there will be UT errors when seq_len=514 on spark; the current workaround is to report xfail, which we should investigate later.
```cpp
// before
float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
// after
float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
```
I don't think these changes matter, but they wouldn't hurt either.
📌 Description
WIP. Do not merge; seeing if this could fix the flaky xqa test.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Tests
Bug Fixes