
Commit efd8554

qsang-nv and yzh119 authored
fix flaky xqa test (#2126)
## 📌 Description

WIP. Do not merge, see if this could fix the xqa flaky test.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Tests**
  * Default test seed changed to improve reproducibility; tests now use batched K/V handling, batched reference comparisons, expanded sequence-length cases, device-based scaling tensors, seeded shuffling, and batch-level validation with adjusted tolerances.
  * Over-provisioned GPU runs now skip instead of failing.
* **Bug Fixes**
  * More consistent attention scaling and more robust GPU attention validation across batched and device-based test paths.

Signed-off-by: Qidi Sang <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
1 parent ecd4ef1 · commit efd8554

File tree

4 files changed: +259 -188 lines changed


csrc/xqa/mha.cu

Lines changed: 2 additions & 2 deletions
@@ -1327,8 +1327,8 @@ CUBIN_EXPORT __global__
     uint32_t kv_stride_page, uint32_t kv_stride_token, uint32_t kv_stride_head,
     uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) {
 
-  float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
-  float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
+  float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
+  float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
   assert(allowMultiBlockMode || gridDim.x == 1);
   bool const isMultiBlock = allowMultiBlockMode && (gridDim.x != 1);
   uint32_t const nbSubSeqPerSeq = allowMultiBlockMode ? gridDim.x : 1;
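
Both spellings read the same address (`*p` and `p[0]` are equivalent in C++), and the identical two-line change is applied in the two remaining kernels below. For context, here is a minimal standalone sketch of the selection pattern these lines implement; the kernel name, signature, and launch are illustrative only, not FlashInfer's actual API:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-in for the patched kernels: the Q and KV-cache scales
// come either from optional device pointers or from host-supplied constants,
// chosen with the same ternary fallback as in the diff above.
__global__ void scaleSelectKernel(float const* __restrict__ qScalePtr,
                                  float const* __restrict__ kvScalePtr,
                                  float qScale, float kvCacheScale,
                                  float* __restrict__ out) {
  float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
  float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
  // The real kernels use these values to scale Q and the KV cache; here they
  // are simply written out so the selection can be checked from the host.
  out[0] = qScaleValue;
  out[1] = kvCacheScaleValue;
}

int main() {
  // Device-resident scale for Q, no device scale for the KV cache, mirroring
  // the "device-based scaling tensors" path exercised by the updated tests.
  float hostQScale = 0.125f;
  float *dQScale = nullptr, *dOut = nullptr;
  cudaMalloc(&dQScale, sizeof(float));
  cudaMalloc(&dOut, 2 * sizeof(float));
  cudaMemcpy(dQScale, &hostQScale, sizeof(float), cudaMemcpyHostToDevice);

  scaleSelectKernel<<<1, 1>>>(dQScale, /*kvScalePtr=*/nullptr,
                              /*qScale=*/1.0f, /*kvCacheScale=*/1.0f, dOut);

  float hostOut[2];
  cudaMemcpy(hostOut, dOut, 2 * sizeof(float), cudaMemcpyDeviceToHost);
  printf("qScaleValue=%f kvCacheScaleValue=%f\n", hostOut[0], hostOut[1]);  // 0.125, 1.0

  cudaFree(dQScale);
  cudaFree(dOut);
  return 0;
}
```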

csrc/xqa/mha_sm90.cu

Lines changed: 2 additions & 2 deletions
@@ -640,8 +640,8 @@ __launch_bounds__(128 * 3)
     uint32_t* __restrict__ const semaphores =
         nullptr,  // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
     void* __restrict__ const scratch = nullptr) {
-  float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
-  float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
+  float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
+  float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \
     (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1
   uint32_t const idxReq = blockIdx.z / nbKHeads;

csrc/xqa/mla_sm120.cu

Lines changed: 2 additions & 2 deletions
@@ -1564,8 +1564,8 @@ __launch_bounds__(32 * 4 * 3, 1) __cluster_dims__(cgaSize, 1, 1) void kernel_mha
     PartialResult* __restrict__ const partialResults =
         nullptr)  // [totalNbInputTokens][maxNbSubSeq]
 {
-  float const qScaleValue = qScalePtr != nullptr ? *qScalePtr : qScale;
-  float const kvCacheScaleValue = kvScalePtr != nullptr ? *kvScalePtr : kvCacheScale;
+  float const qScaleValue = qScalePtr != nullptr ? qScalePtr[0] : qScale;
+  float const kvCacheScaleValue = kvScalePtr != nullptr ? kvScalePtr[0] : kvCacheScale;
   assert(blockDim.x == 32 * 12 && blockDim.y == 1 && blockDim.z == 1);
   extern __shared__ char smemBuf[];
   uint32_t const warpRank = makeWarpUniform(this_warp(), threadIdx.x / warp_size);
