Merged
Changes from 2 commits
4 changes: 2 additions & 2 deletions csrc/xqa/mha.cu
@@ -1310,8 +1310,8 @@ CUBIN_EXPORT __global__
uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {

- float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
- float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
+ float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
Collaborator: I don't think these changes matter, but they wouldn't hurt either.

+ float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
assert(allowMultiBlockMode || gridDim.x == 1);
bool const isMultiBlock = allowMultiBlockMode && (gridDim.x != 1);
uint32_t const nbSubSeqPerSeq = allowMultiBlockMode ? gridDim.x : 1;
4 changes: 2 additions & 2 deletions csrc/xqa/mha_sm90.cu
@@ -640,8 +640,8 @@ __launch_bounds__(128 * 3)
uint32_t* __restrict__ const semaphores =
nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
void* __restrict__ const scratch = nullptr) {
- float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
- float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
+ float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
+ float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \
(IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1
uint32_t const idxReq = blockIdx.z / nbKHeads;
4 changes: 2 additions & 2 deletions csrc/xqa/mla_sm120.cu
@@ -1564,8 +1564,8 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha
PartialResult* __restrict__ const partialResults =
nullptr) // [totalNbInputTokens][maxNbSubSeq]
{
- float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
- float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
+ float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
+ float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1);
extern __shared__ char smemBuf[];
uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size);
19 changes: 13 additions & 6 deletions tests/attention/test_xqa.py
@@ -8,7 +8,7 @@
from flashinfer.utils import get_compute_capability


- def set_random_seed(seed=42):
+ def set_random_seed(seed=0):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
@@ -173,7 +173,7 @@ def test_xqa(
q_scale,
use_fp8_output,
):
- set_random_seed(42)
+ set_random_seed(0)

nb_q_heads = nb_k_heads * head_grp_size

@@ -268,7 +268,9 @@ def test_xqa(

# Shuffle page indices
flattened = page_list_arg.flatten()
- indices = torch.randperm(flattened.numel(), device="cuda")
+ generator = torch.Generator(device="cuda")
+ generator.manual_seed(42)
Contributor (medium): The seed 42 is hardcoded here. It's also used in the set_random_seed(42) call at the beginning of this test on line 177. To improve maintainability and avoid magic numbers, consider defining a constant for the seed (e.g., SEED = 42) at the module level and using it in both places.
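A minimal sketch of that suggestion, assuming the constant is named SEED as in the reviewer's example (the constant and the helper below are illustrative, not part of this PR):

import torch

SEED = 42  # single source of truth for every RNG seed in this test module


def set_random_seed(seed=SEED):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def make_shuffle_generator(seed=SEED):
    # Explicit CUDA generator for the page shuffle below; reuses the same
    # module-level constant instead of a second hardcoded literal.
    generator = torch.Generator(device="cuda")
    generator.manual_seed(seed)
    return generator

Note that in the final diff the two seeds actually differ (set_random_seed(0) versus a generator seeded with 42), so adopting a single constant would also force a decision on which value to keep.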

+ indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
Contributor, commenting on lines +276 to +278:

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

Verify determinism in the MLA test.

The xqa test now uses a seeded generator for torch.randperm (good for determinism), but the xqa_mla test at line 522 still uses the default CUDA RNG without an explicit generator. This inconsistency could lead to non-deterministic behavior in the MLA test.


🏁 Script executed:

#!/bin/bash
# Verify that xqa_mla test uses deterministic RNG for randperm
rg -n -A2 -B2 'def test_xqa_mla' tests/attention/test_xqa.py
rg -n 'torch\.randperm' tests/attention/test_xqa.py



Add seeded generator to test_xqa_mla for deterministic RNG.

The xqa test uses a seeded generator for torch.randperm (line 273), but the xqa_mla test at line 522 still calls torch.randperm without an explicit generator, falling back to the default CUDA RNG. Apply the same generator pattern to test_xqa_mla to ensure deterministic behavior across both tests.

🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 271-273 and at the xqa_mla call near
line 522, the CUDA RNG is not consistently seeded; create a seeded CUDA
generator (generator = torch.Generator(device="cuda");
generator.manual_seed(42)) and pass it into torch.randperm as
generator=generator (keeping device="cuda") in the xqa_mla test so both tests
use the same deterministic RNG source.
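A runnable sketch of that pattern in isolation, with stand-in shapes (batch_size, nb_pages_per_seq, and the page-table contents here are placeholders, not the values the MLA test actually uses):

import torch

# Stand-in page table; the real test builds this from its parametrization.
batch_size, nb_pages_per_seq = 2, 8
page_list_arg = torch.arange(
    batch_size * nb_pages_per_seq, device="cuda", dtype=torch.int32
).view(batch_size, nb_pages_per_seq)

# Seeded CUDA generator so the page shuffle is reproducible across runs.
generator = torch.Generator(device="cuda")
generator.manual_seed(42)

flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
shuffled_flat = flattened[indices]
page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)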

shuffled_flat = flattened[indices]
page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)

@@ -335,6 +337,9 @@ def test_xqa(

rcp_out_scale = 4.0 if use_fp8_output else 1.0

+ torch.cuda.synchronize()
+ semaphores.zero_()

Contributor:

⚠️ Potential issue | 🟠 Major

Critical: MLA test missing synchronization.

The xqa test now includes torch.cuda.synchronize() and semaphores.zero_() before the kernel call; these are critical additions for ensuring proper ordering and a clean state. However, the xqa_mla test (starting line 565) does not include these synchronization calls. Given that this PR aims to fix flaky xqa tests, the missing synchronization in the MLA test is a significant oversight that could cause flakiness.

Apply similar synchronization to the MLA test:

# Add before line 565 (before xqa_mla call)
torch.cuda.synchronize()
semaphores.zero_()
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 340-342 and specifically for the
xqa_mla test starting at line 565, the MLA variant is missing the GPU
synchronization and semaphore reset that were added for the xqa test; before
calling xqa_mla at ~line 565 add a torch.cuda.synchronize() call followed by
semaphores.zero_() (using the same semaphores variable used elsewhere) to ensure
proper ordering and a clean semaphore state before launching the kernel.
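A sketch of that ordering as a small helper, assuming xqa and xqa_mla can be treated as opaque callables here (their real signatures are not reproduced):

import torch


def run_kernel_with_clean_semaphores(kernel_fn, semaphores, *args, **kwargs):
    # Illustrative wrapper, not part of the test: finish any prior GPU work
    # (cache fills, scale-tensor copies), reset the multi-block semaphores the
    # kernel polls, launch, then synchronize again before the host reads and
    # compares the outputs.
    torch.cuda.synchronize()
    semaphores.zero_()
    out = kernel_fn(*args, **kwargs)
    torch.cuda.synchronize()
    return out

Both test_xqa and test_xqa_mla could route their kernel call through such a helper instead of repeating the synchronize/zero_ pair inline.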

xqa(
q_heads,
cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads,
@@ -347,15 +352,17 @@ def test_xqa(
nb_k_heads,
tokens_per_page,
sinks=attention_sinks,
- q_scale=q_scale,
- kv_scale=kv_cache_scale,
+ q_scale=torch.tensor(q_scale, device="cuda"),
+ kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
Contributor, commenting on lines +357 to +358:

⚠️ Potential issue | 🔴 Critical

Critical: MLA test not updated with tensor scales.

The xqa test now passes q_scale and kv_scale as CUDA tensors, aligning with the kernel changes (array indexing in mla_sm120.cu). However, the xqa_mla test at lines 575-576 still passes these as Python scalars. This inconsistency could cause runtime errors or incorrect behavior in the MLA path.

Update the xqa_mla test to use tensor scales:

# Update lines 575-576 in xqa_mla call
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
🤖 Prompt for AI Agents
In tests/attention/test_xqa.py around lines 575 to 576, the xqa_mla test still
passes q_scale and kv_scale as Python scalars while the rest of the tests (and
kernel changes) expect CUDA tensors; update the xqa_mla call to wrap both scales
with torch.tensor(..., device="cuda") so q_scale and kv_scale are passed as CUDA
tensors (matching the change at lines 355-356 and preventing MLA path
runtime/type errors).

sliding_win_size=sliding_win_size,
kv_layout=kv_layout,
sm_count=sm_count,
enable_pdl=enable_pdl,
rcp_out_scale=rcp_out_scale,
)

+ torch.cuda.synchronize()

for req in range(batch_size):
for b in range(beam_width):
for idx_k_head in range(nb_k_heads):
@@ -446,7 +453,7 @@ def test_xqa_mla(
q_scale,
enable_pdl,
):
- set_random_seed(42)
+ set_random_seed(0)

# MLA specific constants (fixed, not parameterized)
nb_k_heads = 1 # MLA only supports 1 K head