fix flaky xqa test #2126
Changes to `tests/attention/test_xqa.py`:

```diff
@@ -8,7 +8,7 @@
 from flashinfer.utils import get_compute_capability


-def set_random_seed(seed=42):
+def set_random_seed(seed=0):
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
```
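Pinning the seed makes every random draw in the test identical from run to run, which is the property a flakiness fix of this kind relies on. A minimal sketch of that property, using the helper as it reads after this change (the final assertion is illustrative, not part of the test):

```python
import torch

def set_random_seed(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Two seeded runs produce identical tensors, so value-dependent checks
# see the same inputs every time the test executes.
set_random_seed(0)
a = torch.rand(4)
set_random_seed(0)
b = torch.rand(4)
assert torch.equal(a, b)
```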
```diff
@@ -173,7 +173,7 @@ def test_xqa(
     q_scale,
     use_fp8_output,
 ):
-    set_random_seed(42)
+    set_random_seed(0)

     nb_q_heads = nb_k_heads * head_grp_size

```
```diff
@@ -268,7 +268,9 @@ def test_xqa(

     # Shuffle page indices
     flattened = page_list_arg.flatten()
-    indices = torch.randperm(flattened.numel(), device="cuda")
+    generator = torch.Generator(device="cuda")
+    generator.manual_seed(42)
+    indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
     shuffled_flat = flattened[indices]
     page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)

```

Review comment on lines +276 to +278 (Contributor):

Verify determinism in the MLA test. The xqa test now uses a seeded generator for `torch.randperm`, so its page shuffle is reproducible; `test_xqa_mla` should be checked for the same property.

Script executed:

```bash
# Verify that xqa_mla test uses deterministic RNG for randperm
rg -n -A2 -B2 'def test_xqa_mla' tests/attention/test_xqa.py
rg -n 'torch\.randperm' tests/attention/test_xqa.py
```

Add seeded generator to test_xqa_mla for deterministic RNG. The xqa test seeds the generator used by its `randperm` call, but `test_xqa_mla` does not, so its page shuffle is still nondeterministic.
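If the same determinism is wanted in `test_xqa_mla`, the seeded-generator pattern from the hunk above carries over directly. A sketch under assumed placeholder shapes; the real test's `page_list_arg` construction may differ:

```python
import torch

# Placeholder page table; shapes and values are assumptions, not the test's real data.
batch_size, nb_pages_per_seq = 2, 4
page_list_arg = torch.arange(
    batch_size * nb_pages_per_seq, dtype=torch.int32, device="cuda"
).view(batch_size, nb_pages_per_seq)

# Same shuffle as in test_xqa, but driven by an explicitly seeded generator
# so the permutation is identical on every run.
flattened = page_list_arg.flatten()
generator = torch.Generator(device="cuda")
generator.manual_seed(42)
indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
shuffled_flat = flattened[indices]
page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)
```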
```diff
@@ -335,6 +337,9 @@ def test_xqa(

     rcp_out_scale = 4.0 if use_fp8_output else 1.0

+    torch.cuda.synchronize()
+    semaphores.zero_()
+
     xqa(
         q_heads,
         cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads,
```
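The two lines added just before the launch read as a reset-before-launch pattern: finish any outstanding GPU work, then clear the reused semaphore buffer, presumably so that state left by a previous parameterization cannot affect this run. A generic sketch of that pattern; the buffer size, dtype, and kernel are placeholders, not the test's real objects:

```python
import torch

# Placeholder semaphore buffer; the real size and dtype in the test are assumptions.
semaphores = torch.zeros(1024, dtype=torch.int32, device="cuda")

def launch(kernel, *args):
    torch.cuda.synchronize()  # wait for any previously enqueued work
    semaphores.zero_()        # reset reused device-side state before this run
    kernel(*args)
    torch.cuda.synchronize()  # wait for the kernel before inspecting outputs

# Trivial placeholder "kernel" to show the call pattern.
out = torch.empty(8, device="cuda")
launch(lambda o: o.fill_(1.0), out)
```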
```diff
@@ -347,15 +352,17 @@ def test_xqa(
         nb_k_heads,
         tokens_per_page,
         sinks=attention_sinks,
-        q_scale=q_scale,
-        kv_scale=kv_cache_scale,
+        q_scale=torch.tensor(q_scale, device="cuda"),
+        kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
         sliding_win_size=sliding_win_size,
         kv_layout=kv_layout,
         sm_count=sm_count,
         enable_pdl=enable_pdl,
         rcp_out_scale=rcp_out_scale,
     )

+    torch.cuda.synchronize()
+
     for req in range(batch_size):
         for b in range(beam_width):
             for idx_k_head in range(nb_k_heads):
```

Review comment on lines +357 to +358 (Contributor):

Critical: MLA test not updated with tensor scales. The xqa test now passes `q_scale` and `kv_scale` as CUDA tensors, while the xqa_mla call still passes plain Python floats. Update the xqa_mla test to use tensor scales:

```python
# Update lines 575-576 in xqa_mla call
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
```
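For clarity, the change the reviewer is pointing at is only in how the scales are passed: a 0-dim CUDA tensor instead of a host-side Python float. A small sketch of that construction, with placeholder values:

```python
import torch

q_scale = 1.0          # placeholder values, not the test's parameterization
kv_cache_scale = 1.0

# What the updated call passes: 0-dim float32 tensors resident on the GPU,
# rather than host-side Python floats.
q_scale_t = torch.tensor(q_scale, device="cuda")
kv_scale_t = torch.tensor(kv_cache_scale, device="cuda")

assert q_scale_t.dim() == 0 and q_scale_t.is_cuda
assert kv_scale_t.dtype == torch.float32
```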
```diff
@@ -446,7 +453,7 @@ def test_xqa_mla(
     q_scale,
     enable_pdl,
 ):
-    set_random_seed(42)
+    set_random_seed(0)

     # MLA specific constants (fixed, not parameterized)
     nb_k_heads = 1  # MLA only supports 1 K head
```
Contributor:

I don't think these changes matter, but they wouldn't hurt either.